diff --git a/CMakeLists.txt b/CMakeLists.txt index f5adb542a..a8cb7701b 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -26,6 +26,10 @@ if(NOT DEFINED VENDOR) set(VENDOR "FreeRDP" CACHE STRING "FreeRDP package vendor") endif() +if(NOT DEFINED FREERDP_VENDOR) + set(FREERDP_VENDOR 1) +endif() + set(CMAKE_COLOR_MAKEFILE ON) set(CMAKE_POSITION_INDEPENDENT_CODE ON) @@ -237,8 +241,8 @@ if(IOS) set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -isysroot ${CMAKE_IOS_SDK_ROOT} -g") endif() -set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DWINPR_EXPORTS") -set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DFREERDP_EXPORTS") +set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DWINPR_EXPORTS -DFREERDP_EXPORTS") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DWINPR_EXPORTS -DFREERDP_EXPORTS") # Include files if(NOT IOS) diff --git a/channels/audin/client/winmm/audin_winmm.c b/channels/audin/client/winmm/audin_winmm.c index 8881c04b0..d0ffb5553 100644 --- a/channels/audin/client/winmm/audin_winmm.c +++ b/channels/audin/client/winmm/audin_winmm.c @@ -138,8 +138,8 @@ static DWORD audin_winmm_thread_func(void* arg) static void audin_winmm_free(IAudinDevice* device) { + UINT32 i; AudinWinmmDevice* winmm = (AudinWinmmDevice*) device; - int i; for (i = 0; i < winmm->cFormats; i++) { @@ -172,8 +172,8 @@ static void audin_winmm_close(IAudinDevice* device) static void audin_winmm_set_format(IAudinDevice* device, audinFormat* format, UINT32 FramesPerPacket) { + UINT32 i; AudinWinmmDevice* winmm = (AudinWinmmDevice*) device; - int i; winmm->frames_per_packet = FramesPerPacket; diff --git a/channels/drive/client/drive_file.c b/channels/drive/client/drive_file.c index eb430fec6..d7b3852fc 100644 --- a/channels/drive/client/drive_file.c +++ b/channels/drive/client/drive_file.c @@ -482,8 +482,8 @@ BOOL drive_file_set_information(DRIVE_FILE* file, UINT32 FsInformationClass, UIN /* http://msdn.microsoft.com/en-us/library/cc232067.aspx */ case FileAllocationInformation: /* http://msdn.microsoft.com/en-us/library/cc232076.aspx */ -#ifndef _WIN32 Stream_Read_UINT64(input, size); +#ifndef _WIN32 if (ftruncate(file->fd, size) != 0) return FALSE; #endif diff --git a/channels/rdpdr/client/rdpdr_main.c b/channels/rdpdr/client/rdpdr_main.c index 2c9fe5350..5759b5154 100644 --- a/channels/rdpdr/client/rdpdr_main.c +++ b/channels/rdpdr/client/rdpdr_main.c @@ -65,8 +65,8 @@ static void rdpdr_send_device_list_announce_request(rdpdrPlugin* rdpdr, BOOL use static void rdpdr_send_device_list_remove_request(rdpdrPlugin* rdpdr, UINT32 count, UINT32 ids[]) { + UINT32 i; wStream* s; - int i; s = Stream_New(NULL, 256); @@ -307,6 +307,7 @@ static char* get_word(char* str, unsigned int* offset) { char* p; char* tmp; + char* word; int wlen; if (*offset >= strlen(str)) @@ -325,7 +326,15 @@ static char* get_word(char* str, unsigned int* offset) while (*(str + *offset) == ' ') (*offset)++; - return strndup(p, wlen); + word = malloc(wlen + 1); + + if (word != NULL) + { + CopyMemory(word, p, wlen); + word[wlen] = '\0'; + } + + return word; } static void handle_hotplug(rdpdrPlugin* rdpdr) diff --git a/channels/smartcard/client/smartcard_main.c b/channels/smartcard/client/smartcard_main.c index fe01b9377..97d9161db 100644 --- a/channels/smartcard/client/smartcard_main.c +++ b/channels/smartcard/client/smartcard_main.c @@ -35,6 +35,85 @@ #include "smartcard_main.h" +void* smartcard_context_thread(SMARTCARD_CONTEXT* pContext) +{ + DWORD nCount; + DWORD status; + HANDLE hEvents[2]; + wMessage message; + SMARTCARD_DEVICE* smartcard; + SMARTCARD_OPERATION* operation; + + smartcard = pContext->smartcard; + + nCount = 0; + hEvents[nCount++] = MessageQueue_Event(pContext->IrpQueue); + + while (1) + { + status = WaitForMultipleObjects(nCount, hEvents, FALSE, INFINITE); + + if (WaitForSingleObject(MessageQueue_Event(pContext->IrpQueue), 0) == WAIT_OBJECT_0) + { + if (!MessageQueue_Peek(pContext->IrpQueue, &message, TRUE)) + break; + + if (message.id == WMQ_QUIT) + break; + + operation = (SMARTCARD_OPERATION*) message.wParam; + + if (operation) + { + status = smartcard_irp_device_control_call(smartcard, operation); + + Queue_Enqueue(smartcard->CompletedIrpQueue, (void*) operation->irp); + + free(operation); + } + } + } + + ExitThread(0); + return NULL; +} + +SMARTCARD_CONTEXT* smartcard_context_new(SMARTCARD_DEVICE* smartcard, SCARDCONTEXT hContext) +{ + SMARTCARD_CONTEXT* pContext; + + pContext = (SMARTCARD_CONTEXT*) calloc(1, sizeof(SMARTCARD_CONTEXT)); + + if (!pContext) + return pContext; + + pContext->smartcard = smartcard; + + pContext->hContext = hContext; + + pContext->IrpQueue = MessageQueue_New(NULL); + + pContext->thread = CreateThread(NULL, 0, + (LPTHREAD_START_ROUTINE) smartcard_context_thread, + pContext, 0, NULL); + + return pContext; +} + +void smartcard_context_free(SMARTCARD_CONTEXT* pContext) +{ + if (!pContext) + return; + + MessageQueue_PostQuit(pContext->IrpQueue, 0); + WaitForSingleObject(pContext->thread, INFINITE); + CloseHandle(pContext->thread); + + MessageQueue_Free(pContext->IrpQueue); + + free(pContext); +} + static void smartcard_free(DEVICE* device) { SMARTCARD_DEVICE* smartcard = (SMARTCARD_DEVICE*) device; @@ -51,6 +130,12 @@ static void smartcard_free(DEVICE* device) ListDictionary_Free(smartcard->rgOutstandingMessages); Queue_Free(smartcard->CompletedIrpQueue); + if (smartcard->StartedEvent) + { + SCardReleaseStartedEvent(); + smartcard->StartedEvent = NULL; + } + free(device); } @@ -65,7 +150,7 @@ static void smartcard_init(DEVICE* device) int keyCount; ULONG_PTR* pKeys; SCARDCONTEXT hContext; - + SMARTCARD_CONTEXT* pContext; SMARTCARD_DEVICE* smartcard = (SMARTCARD_DEVICE*) device; /** @@ -86,7 +171,12 @@ static void smartcard_init(DEVICE* device) for (index = 0; index < keyCount; index++) { - hContext = (SCARDCONTEXT) ListDictionary_GetItemValue(smartcard->rgSCardContextList, (void*) pKeys[index]); + pContext = (SMARTCARD_CONTEXT*) ListDictionary_GetItemValue(smartcard->rgSCardContextList, (void*) pKeys[index]); + + if (!pContext) + continue; + + hContext = pContext->hContext; if (SCardIsValidContext(hContext)) { @@ -108,8 +198,12 @@ static void smartcard_init(DEVICE* device) for (index = 0; index < keyCount; index++) { - hContext = (SCARDCONTEXT) ListDictionary_GetItemValue(smartcard->rgSCardContextList, (void*) pKeys[index]); - ListDictionary_Remove(smartcard->rgSCardContextList, (void*) pKeys[index]); + pContext = (SMARTCARD_CONTEXT*) ListDictionary_Remove(smartcard->rgSCardContextList, (void*) pKeys[index]); + + if (!pContext) + continue; + + hContext = pContext->hContext; if (SCardIsValidContext(hContext)) { @@ -131,16 +225,21 @@ void smartcard_complete_irp(SMARTCARD_DEVICE* smartcard, IRP* irp) irp->Complete(irp); } -void* smartcard_process_irp_worker_proc(IRP* irp) +void* smartcard_process_irp_worker_proc(SMARTCARD_OPERATION* operation) { + IRP* irp; + UINT32 status; SMARTCARD_DEVICE* smartcard; + irp = operation->irp; smartcard = (SMARTCARD_DEVICE*) irp->device; - smartcard_irp_device_control(smartcard, irp); + status = smartcard_irp_device_control_call(smartcard, operation); Queue_Enqueue(smartcard->CompletedIrpQueue, (void*) irp); + free(operation); + ExitThread(0); return NULL; } @@ -153,19 +252,34 @@ void* smartcard_process_irp_worker_proc(IRP* irp) void smartcard_process_irp(SMARTCARD_DEVICE* smartcard, IRP* irp) { void* key; + UINT32 status; BOOL asyncIrp = FALSE; - UINT32 ioControlCode = 0; + SMARTCARD_CONTEXT* pContext = NULL; + SMARTCARD_OPERATION* operation = NULL; key = (void*) (size_t) irp->CompletionId; ListDictionary_Add(smartcard->rgOutstandingMessages, key, irp); if (irp->MajorFunction == IRP_MJ_DEVICE_CONTROL) { - smartcard_irp_device_control_peek_io_control_code(smartcard, irp, &ioControlCode); + operation = (SMARTCARD_OPERATION*) calloc(1, sizeof(SMARTCARD_OPERATION)); - if (!ioControlCode) + if (!operation) return; + operation->irp = irp; + + status = smartcard_irp_device_control_decode(smartcard, operation); + + if (status != SCARD_S_SUCCESS) + { + irp->IoStatus = STATUS_UNSUCCESSFUL; + + Queue_Enqueue(smartcard->CompletedIrpQueue, (void*) irp); + + return; + } + asyncIrp = TRUE; /** @@ -174,7 +288,7 @@ void smartcard_process_irp(SMARTCARD_DEVICE* smartcard, IRP* irp) * those expected to return fast synchronously. */ - switch (ioControlCode) + switch (operation->ioControlCode) { case SCARD_IOCTL_ESTABLISHCONTEXT: case SCARD_IOCTL_RELEASECONTEXT: @@ -237,17 +351,23 @@ void smartcard_process_irp(SMARTCARD_DEVICE* smartcard, IRP* irp) break; } + pContext = ListDictionary_GetItemValue(smartcard->rgSCardContextList, (void*) operation->hContext); + + if (!pContext) + asyncIrp = FALSE; + if (!asyncIrp) { - smartcard_irp_device_control(smartcard, irp); - + status = smartcard_irp_device_control_call(smartcard, operation); Queue_Enqueue(smartcard->CompletedIrpQueue, (void*) irp); + free(operation); } else { - irp->thread = CreateThread(NULL, 0, - (LPTHREAD_START_ROUTINE) smartcard_process_irp_worker_proc, - irp, 0, NULL); + if (pContext) + { + MessageQueue_Post(pContext->IrpQueue, NULL, 0, (void*) operation, NULL); + } } } else diff --git a/channels/smartcard/client/smartcard_main.h b/channels/smartcard/client/smartcard_main.h index c0c95db3d..e0172d424 100644 --- a/channels/smartcard/client/smartcard_main.h +++ b/channels/smartcard/client/smartcard_main.h @@ -81,6 +81,27 @@ #define SCARD_IOCTL_GETREADERICON RDP_SCARD_CTL_CODE(67) /* SCardGetReaderIconA */ #define SCARD_IOCTL_GETDEVICETYPEID RDP_SCARD_CTL_CODE(68) /* SCardGetDeviceTypeIdA */ +typedef struct _SMARTCARD_DEVICE SMARTCARD_DEVICE; + +struct _SMARTCARD_OPERATION +{ + IRP* irp; + void* call; + UINT32 ioControlCode; + SCARDCONTEXT hContext; + SCARDHANDLE hCard; +}; +typedef struct _SMARTCARD_OPERATION SMARTCARD_OPERATION; + +struct _SMARTCARD_CONTEXT +{ + HANDLE thread; + SCARDCONTEXT hContext; + wMessageQueue* IrpQueue; + SMARTCARD_DEVICE* smartcard; +}; +typedef struct _SMARTCARD_CONTEXT SMARTCARD_CONTEXT; + struct _SMARTCARD_DEVICE { DEVICE device; @@ -91,18 +112,21 @@ struct _SMARTCARD_DEVICE char* path; HANDLE thread; + HANDLE StartedEvent; wMessageQueue* IrpQueue; wQueue* CompletedIrpQueue; wListDictionary* rgSCardContextList; wListDictionary* rgOutstandingMessages; }; -typedef struct _SMARTCARD_DEVICE SMARTCARD_DEVICE; + +SMARTCARD_CONTEXT* smartcard_context_new(SMARTCARD_DEVICE* smartcard, SCARDCONTEXT hContext); +void smartcard_context_free(SMARTCARD_CONTEXT* pContext); void smartcard_complete_irp(SMARTCARD_DEVICE* smartcard, IRP* irp); void smartcard_process_irp(SMARTCARD_DEVICE* smartcard, IRP* irp); -void smartcard_irp_device_control(SMARTCARD_DEVICE* smartcard, IRP* irp); -void smartcard_irp_device_control_peek_io_control_code(SMARTCARD_DEVICE* smartcard, IRP* irp, UINT32* ioControlCode); +UINT32 smartcard_irp_device_control_decode(SMARTCARD_DEVICE* smartcard, SMARTCARD_OPERATION* operation); +UINT32 smartcard_irp_device_control_call(SMARTCARD_DEVICE* smartcard, SMARTCARD_OPERATION* operation); #include "smartcard_pack.h" diff --git a/channels/smartcard/client/smartcard_operations.c b/channels/smartcard/client/smartcard_operations.c index 5ef0dd230..a50651472 100644 --- a/channels/smartcard/client/smartcard_operations.c +++ b/channels/smartcard/client/smartcard_operations.c @@ -145,26 +145,38 @@ const char* smartcard_get_ioctl_string(UINT32 ioControlCode, BOOL funcName) return funcName ? "SCardUnknown" : "SCARD_IOCTL_UNKNOWN"; } -static UINT32 smartcard_EstablishContext(SMARTCARD_DEVICE* smartcard, IRP* irp) +static UINT32 smartcard_EstablishContext_Decode(SMARTCARD_DEVICE* smartcard, SMARTCARD_OPERATION* operation, EstablishContext_Call* call) +{ + UINT32 status; + IRP* irp = operation->irp; + + if (!call) + return STATUS_NO_MEMORY; + + status = smartcard_unpack_establish_context_call(smartcard, irp->input, call); + + smartcard_trace_establish_context_call(smartcard, call); + + return status; +} + +static UINT32 smartcard_EstablishContext_Call(SMARTCARD_DEVICE* smartcard, SMARTCARD_OPERATION* operation, EstablishContext_Call* call) { UINT32 status; SCARDCONTEXT hContext = -1; - EstablishContext_Call call; - EstablishContext_Return ret = { 0 }; + EstablishContext_Return ret; + IRP* irp = operation->irp; - status = smartcard_unpack_establish_context_call(smartcard, irp->input, &call); - - smartcard_trace_establish_context_call(smartcard, &call); - - if (status) - return status; - - status = ret.ReturnCode = SCardEstablishContext(call.dwScope, NULL, NULL, &hContext); + status = ret.ReturnCode = SCardEstablishContext(call->dwScope, NULL, NULL, &hContext); if (ret.ReturnCode == SCARD_S_SUCCESS) { + SMARTCARD_CONTEXT* pContext; void* key = (void*) (size_t) hContext; - ListDictionary_Add(smartcard->rgSCardContextList, key, NULL); + + pContext = smartcard_context_new(smartcard, hContext); + + ListDictionary_Add(smartcard->rgSCardContextList, key, (void*) pContext); } smartcard_scard_context_native_to_redir(smartcard, &(ret.hContext), hContext); @@ -179,28 +191,38 @@ static UINT32 smartcard_EstablishContext(SMARTCARD_DEVICE* smartcard, IRP* irp) return ret.ReturnCode; } -static UINT32 smartcard_ReleaseContext(SMARTCARD_DEVICE* smartcard, IRP* irp) +static UINT32 smartcard_ReleaseContext_Decode(SMARTCARD_DEVICE* smartcard, SMARTCARD_OPERATION* operation, Context_Call* call) +{ + UINT32 status; + IRP* irp = operation->irp; + + if (!call) + return STATUS_NO_MEMORY; + + status = smartcard_unpack_context_call(smartcard, irp->input, call); + + smartcard_trace_context_call(smartcard, call, "ReleaseContext"); + + operation->hContext = smartcard_scard_context_native_from_redir(smartcard, &(call->hContext)); + + return status; +} + +static UINT32 smartcard_ReleaseContext_Call(SMARTCARD_DEVICE* smartcard, SMARTCARD_OPERATION* operation, Context_Call* call) { UINT32 status; - Context_Call call; Long_Return ret; - SCARDCONTEXT hContext; - status = smartcard_unpack_context_call(smartcard, irp->input, &call); - - smartcard_trace_context_call(smartcard, &call, "ReleaseContext"); - - if (status) - return status; - - hContext = smartcard_scard_context_native_from_redir(smartcard, &(call.hContext)); - - status = ret.ReturnCode = SCardReleaseContext(hContext); + status = ret.ReturnCode = SCardReleaseContext(operation->hContext); if (ret.ReturnCode == SCARD_S_SUCCESS) { - void* key = (void*) (size_t) hContext; - ListDictionary_Remove(smartcard->rgSCardContextList, key); + SMARTCARD_CONTEXT* pContext; + void* key = (void*) (size_t) operation->hContext; + + pContext = (SMARTCARD_CONTEXT*) ListDictionary_Remove(smartcard->rgSCardContextList, key); + + smartcard_context_free(pContext); } smartcard_trace_long_return(smartcard, &ret, "ReleaseContext"); @@ -208,50 +230,63 @@ static UINT32 smartcard_ReleaseContext(SMARTCARD_DEVICE* smartcard, IRP* irp) return ret.ReturnCode; } -static UINT32 smartcard_IsValidContext(SMARTCARD_DEVICE* smartcard, IRP* irp) +static UINT32 smartcard_IsValidContext_Decode(SMARTCARD_DEVICE* smartcard, SMARTCARD_OPERATION* operation, Context_Call* call) +{ + UINT32 status; + IRP* irp = operation->irp; + + if (!call) + return STATUS_NO_MEMORY; + + status = smartcard_unpack_context_call(smartcard, irp->input, call); + + smartcard_trace_context_call(smartcard, call, "IsValidContext"); + + operation->hContext = smartcard_scard_context_native_from_redir(smartcard, &(call->hContext)); + + return status; +} + +static UINT32 smartcard_IsValidContext_Call(SMARTCARD_DEVICE* smartcard, SMARTCARD_OPERATION* operation, Context_Call* call) { UINT32 status; - Context_Call call; Long_Return ret; - SCARDCONTEXT hContext; - status = smartcard_unpack_context_call(smartcard, irp->input, &call); - - smartcard_trace_context_call(smartcard, &call, "IsValidContext"); - - if (status) - return status; - - hContext = smartcard_scard_context_native_from_redir(smartcard, &(call.hContext)); - - status = ret.ReturnCode = SCardIsValidContext(hContext); + status = ret.ReturnCode = SCardIsValidContext(operation->hContext); smartcard_trace_long_return(smartcard, &ret, "IsValidContext"); return ret.ReturnCode; } -static UINT32 smartcard_ListReadersA(SMARTCARD_DEVICE* smartcard, IRP* irp) +static UINT32 smartcard_ListReadersA_Decode(SMARTCARD_DEVICE* smartcard, SMARTCARD_OPERATION* operation, ListReaders_Call* call) +{ + UINT32 status; + IRP* irp = operation->irp; + + if (!call) + return STATUS_NO_MEMORY; + + status = smartcard_unpack_list_readers_call(smartcard, irp->input, call); + + smartcard_trace_list_readers_call(smartcard, call, FALSE); + + operation->hContext = smartcard_scard_context_native_from_redir(smartcard, &(call->hContext)); + + return status; +} + +static UINT32 smartcard_ListReadersA_Call(SMARTCARD_DEVICE* smartcard, SMARTCARD_OPERATION* operation, ListReaders_Call* call) { UINT32 status; - SCARDCONTEXT hContext; - ListReaders_Call call; ListReaders_Return ret; LPSTR mszReaders = NULL; DWORD cchReaders = 0; - - status = smartcard_unpack_list_readers_call(smartcard, irp->input, &call); - - smartcard_trace_list_readers_call(smartcard, &call, FALSE); - - if (status) - return status; - - hContext = smartcard_scard_context_native_from_redir(smartcard, &(call.hContext)); + IRP* irp = operation->irp; cchReaders = SCARD_AUTOALLOCATE; - status = ret.ReturnCode = SCardListReadersA(hContext, (LPCSTR) call.mszGroups, (LPSTR) &mszReaders, &cchReaders); + status = ret.ReturnCode = SCardListReadersA(operation->hContext, (LPCSTR) call->mszGroups, (LPSTR) &mszReaders, &cchReaders); ret.msz = (BYTE*) mszReaders; ret.cBytes = cchReaders; @@ -267,90 +302,104 @@ static UINT32 smartcard_ListReadersA(SMARTCARD_DEVICE* smartcard, IRP* irp) return status; if (mszReaders) - SCardFreeMemory(hContext, mszReaders); + SCardFreeMemory(operation->hContext, mszReaders); - if (call.mszGroups) - free(call.mszGroups); + if (call->mszGroups) + free(call->mszGroups); return ret.ReturnCode; } -static UINT32 smartcard_ListReadersW(SMARTCARD_DEVICE* smartcard, IRP* irp) +static UINT32 smartcard_ListReadersW_Decode(SMARTCARD_DEVICE* smartcard, SMARTCARD_OPERATION* operation, ListReaders_Call* call) +{ + UINT32 status; + IRP* irp = operation->irp; + + if (!call) + return STATUS_NO_MEMORY; + + status = smartcard_unpack_list_readers_call(smartcard, irp->input, call); + + smartcard_trace_list_readers_call(smartcard, call, TRUE); + + operation->hContext = smartcard_scard_context_native_from_redir(smartcard, &(call->hContext)); + + return status; +} + +static UINT32 smartcard_ListReadersW_Call(SMARTCARD_DEVICE* smartcard, SMARTCARD_OPERATION* operation, ListReaders_Call* call) { UINT32 status; - SCARDCONTEXT hContext; - ListReaders_Call call; ListReaders_Return ret; LPWSTR mszReaders = NULL; DWORD cchReaders = 0; - - status = smartcard_unpack_list_readers_call(smartcard, irp->input, &call); - - smartcard_trace_list_readers_call(smartcard, &call, TRUE); - - if (status) - return status; - - hContext = smartcard_scard_context_native_from_redir(smartcard, &(call.hContext)); + IRP* irp = operation->irp; cchReaders = SCARD_AUTOALLOCATE; - status = ret.ReturnCode = SCardListReadersW(hContext, (LPCWSTR) call.mszGroups, (LPWSTR) &mszReaders, &cchReaders); + status = ret.ReturnCode = SCardListReadersW(operation->hContext, (LPCWSTR) call->mszGroups, (LPWSTR) &mszReaders, &cchReaders); ret.msz = (BYTE*) mszReaders; ret.cBytes = cchReaders * 2; - if (status) + if (status != SCARD_S_SUCCESS) return status; smartcard_trace_list_readers_return(smartcard, &ret, TRUE); status = smartcard_pack_list_readers_return(smartcard, irp->output, &ret); - if (status) + if (status != SCARD_S_SUCCESS) return status; if (mszReaders) - SCardFreeMemory(hContext, mszReaders); + SCardFreeMemory(operation->hContext, mszReaders); - if (call.mszGroups) - free(call.mszGroups); + if (call->mszGroups) + free(call->mszGroups); return ret.ReturnCode; } -static UINT32 smartcard_GetStatusChangeA(SMARTCARD_DEVICE* smartcard, IRP* irp) +static UINT32 smartcard_GetStatusChangeA_Decode(SMARTCARD_DEVICE* smartcard, SMARTCARD_OPERATION* operation, GetStatusChangeA_Call* call) +{ + LONG status; + IRP* irp = operation->irp; + + if (!call) + return STATUS_NO_MEMORY; + + status = smartcard_unpack_get_status_change_a_call(smartcard, irp->input, call); + + smartcard_trace_get_status_change_a_call(smartcard, call); + + operation->hContext = smartcard_scard_context_native_from_redir(smartcard, &(call->hContext)); + + return status; +} + +static UINT32 smartcard_GetStatusChangeA_Call(SMARTCARD_DEVICE* smartcard, SMARTCARD_OPERATION* operation, GetStatusChangeA_Call* call) { LONG status; UINT32 index; - SCARDCONTEXT hContext; - GetStatusChangeA_Call call; GetStatusChange_Return ret; LPSCARD_READERSTATEA rgReaderState = NULL; + IRP* irp = operation->irp; - status = smartcard_unpack_get_status_change_a_call(smartcard, irp->input, &call); - - smartcard_trace_get_status_change_a_call(smartcard, &call); - - if (status) - return status; - - hContext = smartcard_scard_context_native_from_redir(smartcard, &(call.hContext)); - - status = ret.ReturnCode = SCardGetStatusChangeA(hContext, call.dwTimeOut, call.rgReaderStates, call.cReaders); + status = ret.ReturnCode = SCardGetStatusChangeA(operation->hContext, call->dwTimeOut, call->rgReaderStates, call->cReaders); if (status && (status != SCARD_E_TIMEOUT) && (status != SCARD_E_CANCELLED)) return status; - ret.cReaders = call.cReaders; + ret.cReaders = call->cReaders; ret.rgReaderStates = (ReaderState_Return*) calloc(ret.cReaders, sizeof(ReaderState_Return)); for (index = 0; index < ret.cReaders; index++) { - ret.rgReaderStates[index].dwCurrentState = call.rgReaderStates[index].dwCurrentState; - ret.rgReaderStates[index].dwEventState = call.rgReaderStates[index].dwEventState; - ret.rgReaderStates[index].cbAtr = call.rgReaderStates[index].cbAtr; - CopyMemory(&(ret.rgReaderStates[index].rgbAtr), &(call.rgReaderStates[index].rgbAtr), 32); + ret.rgReaderStates[index].dwCurrentState = call->rgReaderStates[index].dwCurrentState; + ret.rgReaderStates[index].dwEventState = call->rgReaderStates[index].dwEventState; + ret.rgReaderStates[index].cbAtr = call->rgReaderStates[index].cbAtr; + CopyMemory(&(ret.rgReaderStates[index].rgbAtr), &(call->rgReaderStates[index].rgbAtr), 32); } smartcard_trace_get_status_change_return(smartcard, &ret, FALSE); @@ -360,16 +409,16 @@ static UINT32 smartcard_GetStatusChangeA(SMARTCARD_DEVICE* smartcard, IRP* irp) if (status) return status; - if (call.rgReaderStates) + if (call->rgReaderStates) { - for (index = 0; index < call.cReaders; index++) + for (index = 0; index < call->cReaders; index++) { - rgReaderState = &call.rgReaderStates[index]; + rgReaderState = &call->rgReaderStates[index]; if (rgReaderState->szReader) free((void*) rgReaderState->szReader); } - free(call.rgReaderStates); + free(call->rgReaderStates); } free(ret.rgReaderStates); @@ -377,38 +426,45 @@ static UINT32 smartcard_GetStatusChangeA(SMARTCARD_DEVICE* smartcard, IRP* irp) return ret.ReturnCode; } -static UINT32 smartcard_GetStatusChangeW(SMARTCARD_DEVICE* smartcard, IRP* irp) +static UINT32 smartcard_GetStatusChangeW_Decode(SMARTCARD_DEVICE* smartcard, SMARTCARD_OPERATION* operation, GetStatusChangeW_Call* call) +{ + LONG status; + IRP* irp = operation->irp; + + if (!call) + return STATUS_NO_MEMORY; + + status = smartcard_unpack_get_status_change_w_call(smartcard, irp->input, call); + + smartcard_trace_get_status_change_w_call(smartcard, call); + + operation->hContext = smartcard_scard_context_native_from_redir(smartcard, &(call->hContext)); + + return status; +} + +static UINT32 smartcard_GetStatusChangeW_Call(SMARTCARD_DEVICE* smartcard, SMARTCARD_OPERATION* operation, GetStatusChangeW_Call* call) { LONG status; UINT32 index; - SCARDCONTEXT hContext; - GetStatusChangeW_Call call; GetStatusChange_Return ret; LPSCARD_READERSTATEW rgReaderState = NULL; + IRP* irp = operation->irp; - status = smartcard_unpack_get_status_change_w_call(smartcard, irp->input, &call); - - smartcard_trace_get_status_change_w_call(smartcard, &call); - - if (status) - return status; - - hContext = smartcard_scard_context_native_from_redir(smartcard, &(call.hContext)); - - status = ret.ReturnCode = SCardGetStatusChangeW(hContext, call.dwTimeOut, call.rgReaderStates, call.cReaders); + status = ret.ReturnCode = SCardGetStatusChangeW(operation->hContext, call->dwTimeOut, call->rgReaderStates, call->cReaders); if (status && (status != SCARD_E_TIMEOUT) && (status != SCARD_E_CANCELLED)) return status; - ret.cReaders = call.cReaders; + ret.cReaders = call->cReaders; ret.rgReaderStates = (ReaderState_Return*) calloc(ret.cReaders, sizeof(ReaderState_Return)); for (index = 0; index < ret.cReaders; index++) { - ret.rgReaderStates[index].dwCurrentState = call.rgReaderStates[index].dwCurrentState; - ret.rgReaderStates[index].dwEventState = call.rgReaderStates[index].dwEventState; - ret.rgReaderStates[index].cbAtr = call.rgReaderStates[index].cbAtr; - CopyMemory(&(ret.rgReaderStates[index].rgbAtr), &(call.rgReaderStates[index].rgbAtr), 32); + ret.rgReaderStates[index].dwCurrentState = call->rgReaderStates[index].dwCurrentState; + ret.rgReaderStates[index].dwEventState = call->rgReaderStates[index].dwEventState; + ret.rgReaderStates[index].cbAtr = call->rgReaderStates[index].cbAtr; + CopyMemory(&(ret.rgReaderStates[index].rgbAtr), &(call->rgReaderStates[index].rgbAtr), 32); } smartcard_trace_get_status_change_return(smartcard, &ret, TRUE); @@ -418,16 +474,16 @@ static UINT32 smartcard_GetStatusChangeW(SMARTCARD_DEVICE* smartcard, IRP* irp) if (status) return status; - if (call.rgReaderStates) + if (call->rgReaderStates) { - for (index = 0; index < call.cReaders; index++) + for (index = 0; index < call->cReaders; index++) { - rgReaderState = &call.rgReaderStates[index]; + rgReaderState = &call->rgReaderStates[index]; if (rgReaderState->szReader) free((void*) rgReaderState->szReader); } - free(call.rgReaderStates); + free(call->rgReaderStates); } free(ret.rgReaderStates); @@ -435,61 +491,72 @@ static UINT32 smartcard_GetStatusChangeW(SMARTCARD_DEVICE* smartcard, IRP* irp) return ret.ReturnCode; } -static UINT32 smartcard_Cancel(SMARTCARD_DEVICE* smartcard, IRP* irp) +static UINT32 smartcard_Cancel_Decode(SMARTCARD_DEVICE* smartcard, SMARTCARD_OPERATION* operation, Context_Call* call) +{ + LONG status; + IRP* irp = operation->irp; + + if (!call) + return STATUS_NO_MEMORY; + + status = smartcard_unpack_context_call(smartcard, irp->input, call); + + smartcard_trace_context_call(smartcard, call, "Cancel"); + + operation->hContext = smartcard_scard_context_native_from_redir(smartcard, &(call->hContext)); + + return status; +} + +static UINT32 smartcard_Cancel_Call(SMARTCARD_DEVICE* smartcard, SMARTCARD_OPERATION* operation, Context_Call* call) { LONG status; - SCARDCONTEXT hContext; - Context_Call call; Long_Return ret; - status = smartcard_unpack_context_call(smartcard, irp->input, &call); - - smartcard_trace_context_call(smartcard, &call, "Cancel"); - - if (status) - return status; - - hContext = smartcard_scard_context_native_from_redir(smartcard, &(call.hContext)); - - status = ret.ReturnCode = SCardCancel(hContext); + status = ret.ReturnCode = SCardCancel(operation->hContext); smartcard_trace_long_return(smartcard, &ret, "Cancel"); return ret.ReturnCode; } -UINT32 smartcard_ConnectA(SMARTCARD_DEVICE* smartcard, IRP* irp) +static UINT32 smartcard_ConnectA_Decode(SMARTCARD_DEVICE* smartcard, SMARTCARD_OPERATION* operation, ConnectA_Call* call) +{ + LONG status; + IRP* irp = operation->irp; + + if (!call) + return STATUS_NO_MEMORY; + + status = smartcard_unpack_connect_a_call(smartcard, irp->input, call); + + smartcard_trace_connect_a_call(smartcard, call); + + operation->hContext = smartcard_scard_context_native_from_redir(smartcard, &(call->Common.hContext)); + + return status; +} + +static UINT32 smartcard_ConnectA_Call(SMARTCARD_DEVICE* smartcard, SMARTCARD_OPERATION* operation, ConnectA_Call* call) { LONG status; SCARDHANDLE hCard; - SCARDCONTEXT hContext; - ConnectA_Call call; Connect_Return ret; + IRP* irp = operation->irp; - call.szReader = NULL; - - status = smartcard_unpack_connect_a_call(smartcard, irp->input, &call); - - smartcard_trace_connect_a_call(smartcard, &call); - - if (status) - return status; - - hContext = smartcard_scard_context_native_from_redir(smartcard, &(call.Common.hContext)); - - if ((call.Common.dwPreferredProtocols == SCARD_PROTOCOL_UNDEFINED) && - (call.Common.dwShareMode != SCARD_SHARE_DIRECT)) + if ((call->Common.dwPreferredProtocols == SCARD_PROTOCOL_UNDEFINED) && + (call->Common.dwShareMode != SCARD_SHARE_DIRECT)) { - call.Common.dwPreferredProtocols = SCARD_PROTOCOL_Tx; + call->Common.dwPreferredProtocols = SCARD_PROTOCOL_Tx; } - status = ret.ReturnCode = SCardConnectA(hContext, (char*) call.szReader, call.Common.dwShareMode, - call.Common.dwPreferredProtocols, &hCard, &ret.dwActiveProtocol); + status = ret.ReturnCode = SCardConnectA(operation->hContext, (char*) call->szReader, call->Common.dwShareMode, + call->Common.dwPreferredProtocols, &hCard, &ret.dwActiveProtocol); if (status) return status; - smartcard_scard_context_native_to_redir(smartcard, &(ret.hContext), hContext); + smartcard_scard_context_native_to_redir(smartcard, &(ret.hContext), operation->hContext); smartcard_scard_handle_native_to_redir(smartcard, &(ret.hCard), hCard); smartcard_trace_connect_return(smartcard, &ret); @@ -499,44 +566,49 @@ UINT32 smartcard_ConnectA(SMARTCARD_DEVICE* smartcard, IRP* irp) if (status) return status; - if (call.szReader) - free(call.szReader); + if (call->szReader) + free(call->szReader); return ret.ReturnCode; } -UINT32 smartcard_ConnectW(SMARTCARD_DEVICE* smartcard, IRP* irp) +static UINT32 smartcard_ConnectW_Decode(SMARTCARD_DEVICE* smartcard, SMARTCARD_OPERATION* operation, ConnectW_Call* call) +{ + LONG status; + IRP* irp = operation->irp; + + if (!call) + return STATUS_NO_MEMORY; + + status = smartcard_unpack_connect_w_call(smartcard, irp->input, call); + + smartcard_trace_connect_w_call(smartcard, call); + + operation->hContext = smartcard_scard_context_native_from_redir(smartcard, &(call->Common.hContext)); + + return status; +} + +static UINT32 smartcard_ConnectW_Call(SMARTCARD_DEVICE* smartcard, SMARTCARD_OPERATION* operation, ConnectW_Call* call) { LONG status; - SCARDCONTEXT hContext; SCARDHANDLE hCard; - ConnectW_Call call; Connect_Return ret; + IRP* irp = operation->irp; - call.szReader = NULL; - - status = smartcard_unpack_connect_w_call(smartcard, irp->input, &call); - - smartcard_trace_connect_w_call(smartcard, &call); - - if (status) - return status; - - hContext = smartcard_scard_context_native_from_redir(smartcard, &(call.Common.hContext)); - - if ((call.Common.dwPreferredProtocols == SCARD_PROTOCOL_UNDEFINED) && - (call.Common.dwShareMode != SCARD_SHARE_DIRECT)) + if ((call->Common.dwPreferredProtocols == SCARD_PROTOCOL_UNDEFINED) && + (call->Common.dwShareMode != SCARD_SHARE_DIRECT)) { - call.Common.dwPreferredProtocols = SCARD_PROTOCOL_Tx; + call->Common.dwPreferredProtocols = SCARD_PROTOCOL_Tx; } - status = ret.ReturnCode = SCardConnectW(hContext, (WCHAR*) call.szReader, call.Common.dwShareMode, - call.Common.dwPreferredProtocols, &hCard, &ret.dwActiveProtocol); + status = ret.ReturnCode = SCardConnectW(operation->hContext, (WCHAR*) call->szReader, call->Common.dwShareMode, + call->Common.dwPreferredProtocols, &hCard, &ret.dwActiveProtocol); if (status) return status; - smartcard_scard_context_native_to_redir(smartcard, &(ret.hContext), hContext); + smartcard_scard_context_native_to_redir(smartcard, &(ret.hContext), operation->hContext); smartcard_scard_handle_native_to_redir(smartcard, &(ret.hCard), hCard); smartcard_trace_connect_return(smartcard, &ret); @@ -546,193 +618,211 @@ UINT32 smartcard_ConnectW(SMARTCARD_DEVICE* smartcard, IRP* irp) if (status) return status; - if (call.szReader) - free(call.szReader); + if (call->szReader) + free(call->szReader); return ret.ReturnCode; } -static UINT32 smartcard_Reconnect(SMARTCARD_DEVICE* smartcard, IRP* irp) +static UINT32 smartcard_Reconnect_Decode(SMARTCARD_DEVICE* smartcard, SMARTCARD_OPERATION* operation, Reconnect_Call* call) +{ + LONG status; + IRP* irp = operation->irp; + + if (!call) + return STATUS_NO_MEMORY; + + status = smartcard_unpack_reconnect_call(smartcard, irp->input, call); + + smartcard_trace_reconnect_call(smartcard, call); + + operation->hContext = smartcard_scard_context_native_from_redir(smartcard, &(call->hContext)); + operation->hCard = smartcard_scard_handle_native_from_redir(smartcard, &(call->hCard)); + + return status; +} + +static UINT32 smartcard_Reconnect_Call(SMARTCARD_DEVICE* smartcard, SMARTCARD_OPERATION* operation, Reconnect_Call* call) { LONG status; - SCARDHANDLE hCard; - SCARDCONTEXT hContext; - Reconnect_Call call; Reconnect_Return ret; + IRP* irp = operation->irp; - status = smartcard_unpack_reconnect_call(smartcard, irp->input, &call); - - smartcard_trace_reconnect_call(smartcard, &call); - - if (status) - return status; - - hContext = smartcard_scard_context_native_from_redir(smartcard, &(call.hContext)); - hCard = smartcard_scard_handle_native_from_redir(smartcard, &(call.hCard)); - - status = ret.ReturnCode = SCardReconnect(hCard, call.dwShareMode, - call.dwPreferredProtocols, call.dwInitialization, &ret.dwActiveProtocol); - - if (status) - return status; + status = ret.ReturnCode = SCardReconnect(operation->hCard, call->dwShareMode, + call->dwPreferredProtocols, call->dwInitialization, &ret.dwActiveProtocol); smartcard_trace_reconnect_return(smartcard, &ret); status = smartcard_pack_reconnect_return(smartcard, irp->output, &ret); - if (status) + if (status != SCARD_S_SUCCESS) return status; return ret.ReturnCode; } -static UINT32 smartcard_Disconnect(SMARTCARD_DEVICE* smartcard, IRP* irp) +static UINT32 smartcard_Disconnect_Decode(SMARTCARD_DEVICE* smartcard, SMARTCARD_OPERATION* operation, HCardAndDisposition_Call* call) +{ + LONG status; + IRP* irp = operation->irp; + + if (!call) + return STATUS_NO_MEMORY; + + status = smartcard_unpack_hcard_and_disposition_call(smartcard, irp->input, call); + + smartcard_trace_hcard_and_disposition_call(smartcard, call, "Disconnect"); + + operation->hContext = smartcard_scard_context_native_from_redir(smartcard, &(call->hContext)); + operation->hCard = smartcard_scard_handle_native_from_redir(smartcard, &(call->hCard)); + + return status; +} + +static UINT32 smartcard_Disconnect_Call(SMARTCARD_DEVICE* smartcard, SMARTCARD_OPERATION* operation, HCardAndDisposition_Call* call) { LONG status; - SCARDHANDLE hCard; - SCARDCONTEXT hContext; - HCardAndDisposition_Call call; Long_Return ret; - status = smartcard_unpack_hcard_and_disposition_call(smartcard, irp->input, &call); - - smartcard_trace_hcard_and_disposition_call(smartcard, &call, "Disconnect"); - - if (status) - return status; - - hContext = smartcard_scard_context_native_from_redir(smartcard, &(call.hContext)); - hCard = smartcard_scard_handle_native_from_redir(smartcard, &(call.hCard)); - - status = ret.ReturnCode = SCardDisconnect(hCard, call.dwDisposition); + status = ret.ReturnCode = SCardDisconnect(operation->hCard, call->dwDisposition); smartcard_trace_long_return(smartcard, &ret, "Disconnect"); - if (status) + if (status != SCARD_S_SUCCESS) return status; return ret.ReturnCode; } -static UINT32 smartcard_BeginTransaction(SMARTCARD_DEVICE* smartcard, IRP* irp) +static UINT32 smartcard_BeginTransaction_Decode(SMARTCARD_DEVICE* smartcard, SMARTCARD_OPERATION* operation, HCardAndDisposition_Call* call) +{ + LONG status; + IRP* irp = operation->irp; + + if (!call) + return STATUS_NO_MEMORY; + + status = smartcard_unpack_hcard_and_disposition_call(smartcard, irp->input, call); + + smartcard_trace_hcard_and_disposition_call(smartcard, call, "BeginTransaction"); + + operation->hContext = smartcard_scard_context_native_from_redir(smartcard, &(call->hContext)); + operation->hCard = smartcard_scard_handle_native_from_redir(smartcard, &(call->hCard)); + + return status; +} + +static UINT32 smartcard_BeginTransaction_Call(SMARTCARD_DEVICE* smartcard, SMARTCARD_OPERATION* operation, HCardAndDisposition_Call* call) { LONG status; - SCARDHANDLE hCard; - SCARDCONTEXT hContext; - HCardAndDisposition_Call call; Long_Return ret; - status = smartcard_unpack_hcard_and_disposition_call(smartcard, irp->input, &call); - - smartcard_trace_hcard_and_disposition_call(smartcard, &call, "BeginTransaction"); - - if (status) - return status; - - hContext = smartcard_scard_context_native_from_redir(smartcard, &(call.hContext)); - hCard = smartcard_scard_handle_native_from_redir(smartcard, &(call.hCard)); - - status = ret.ReturnCode = SCardBeginTransaction(hCard); + status = ret.ReturnCode = SCardBeginTransaction(operation->hCard); smartcard_trace_long_return(smartcard, &ret, "BeginTransaction"); - if (status) - return status; - return ret.ReturnCode; } -static UINT32 smartcard_EndTransaction(SMARTCARD_DEVICE* smartcard, IRP* irp) +static UINT32 smartcard_EndTransaction_Decode(SMARTCARD_DEVICE* smartcard, SMARTCARD_OPERATION* operation, HCardAndDisposition_Call* call) +{ + LONG status; + IRP* irp = operation->irp; + + if (!call) + return STATUS_NO_MEMORY; + + status = smartcard_unpack_hcard_and_disposition_call(smartcard, irp->input, call); + + smartcard_trace_hcard_and_disposition_call(smartcard, call, "EndTransaction"); + + operation->hContext = smartcard_scard_context_native_from_redir(smartcard, &(call->hContext)); + operation->hCard = smartcard_scard_handle_native_from_redir(smartcard, &(call->hCard)); + + return status; +} + +static UINT32 smartcard_EndTransaction_Call(SMARTCARD_DEVICE* smartcard, SMARTCARD_OPERATION* operation, HCardAndDisposition_Call* call) { LONG status; - SCARDHANDLE hCard; - SCARDCONTEXT hContext; - HCardAndDisposition_Call call; Long_Return ret; - status = smartcard_unpack_hcard_and_disposition_call(smartcard, irp->input, &call); - - smartcard_trace_hcard_and_disposition_call(smartcard, &call, "EndTransaction"); - - if (status) - return status; - - hContext = smartcard_scard_context_native_from_redir(smartcard, &(call.hContext)); - hCard = smartcard_scard_handle_native_from_redir(smartcard, &(call.hCard)); - - status = ret.ReturnCode = SCardEndTransaction(hCard, call.dwDisposition); + status = ret.ReturnCode = SCardEndTransaction(operation->hCard, call->dwDisposition); smartcard_trace_long_return(smartcard, &ret, "EndTransaction"); - if (status) - return status; - return ret.ReturnCode; } -static UINT32 smartcard_State(SMARTCARD_DEVICE* smartcard, IRP* irp) +static UINT32 smartcard_State_Decode(SMARTCARD_DEVICE* smartcard, SMARTCARD_OPERATION* operation, State_Call* call) +{ + LONG status; + IRP* irp = operation->irp; + + if (!call) + return STATUS_NO_MEMORY; + + status = smartcard_unpack_state_call(smartcard, irp->input, call); + + operation->hContext = smartcard_scard_context_native_from_redir(smartcard, &(call->hContext)); + operation->hCard = smartcard_scard_handle_native_from_redir(smartcard, &(call->hCard)); + + return status; +} + +static UINT32 smartcard_State_Call(SMARTCARD_DEVICE* smartcard, SMARTCARD_OPERATION* operation, State_Call* call) { LONG status; - SCARDHANDLE hCard; - SCARDCONTEXT hContext; - State_Call call; State_Return ret; - - status = smartcard_unpack_state_call(smartcard, irp->input, &call); - - if (status) - return status; - - hContext = smartcard_scard_context_native_from_redir(smartcard, &(call.hContext)); - hCard = smartcard_scard_handle_native_from_redir(smartcard, &(call.hCard)); + IRP* irp = operation->irp; ret.cbAtrLen = SCARD_ATR_LENGTH; - status = ret.ReturnCode = SCardState(hCard, &ret.dwState, &ret.dwProtocol, (BYTE*) &ret.rgAtr, &ret.cbAtrLen); - - if (ret.ReturnCode) - { - Stream_Zero(irp->output, 256); - return ret.ReturnCode; - } + status = ret.ReturnCode = SCardState(operation->hCard, &ret.dwState, &ret.dwProtocol, (BYTE*) &ret.rgAtr, &ret.cbAtrLen); status = smartcard_pack_state_return(smartcard, irp->output, &ret); - if (status) + if (status != SCARD_S_SUCCESS) return status; return ret.ReturnCode; } -static DWORD smartcard_StatusA(SMARTCARD_DEVICE* smartcard, IRP* irp) +static DWORD smartcard_StatusA_Decode(SMARTCARD_DEVICE* smartcard, SMARTCARD_OPERATION* operation, Status_Call* call) +{ + LONG status; + IRP* irp = operation->irp; + + if (!call) + return STATUS_NO_MEMORY; + + status = smartcard_unpack_status_call(smartcard, irp->input, call); + + smartcard_trace_status_call(smartcard, call, FALSE); + + operation->hContext = smartcard_scard_context_native_from_redir(smartcard, &(call->hContext)); + operation->hCard = smartcard_scard_handle_native_from_redir(smartcard, &(call->hCard)); + + return status; +} + +static DWORD smartcard_StatusA_Call(SMARTCARD_DEVICE* smartcard, SMARTCARD_OPERATION* operation, Status_Call* call) { LONG status; - SCARDHANDLE hCard; - SCARDCONTEXT hContext; - Status_Call call; Status_Return ret = { 0 }; DWORD cchReaderLen = 0; LPSTR mszReaderNames = NULL; + IRP* irp = operation->irp; - status = smartcard_unpack_status_call(smartcard, irp->input, &call); + if (call->cbAtrLen > 32) + call->cbAtrLen = 32; - smartcard_trace_status_call(smartcard, &call, FALSE); - - if (status) - return status; - - hContext = smartcard_scard_context_native_from_redir(smartcard, &(call.hContext)); - hCard = smartcard_scard_handle_native_from_redir(smartcard, &(call.hCard)); - - if (call.cbAtrLen > 32) - call.cbAtrLen = 32; - - ret.cbAtrLen = call.cbAtrLen; + ret.cbAtrLen = call->cbAtrLen; ZeroMemory(ret.pbAtr, 32); cchReaderLen = SCARD_AUTOALLOCATE; - status = ret.ReturnCode = SCardStatusA(hCard, (LPSTR) &mszReaderNames, &cchReaderLen, + status = ret.ReturnCode = SCardStatusA(operation->hCard, (LPSTR) &mszReaderNames, &cchReaderLen, &ret.dwState, &ret.dwProtocol, (BYTE*) &ret.pbAtr, &ret.cbAtrLen); if (status == SCARD_S_SUCCESS) @@ -745,44 +835,50 @@ static DWORD smartcard_StatusA(SMARTCARD_DEVICE* smartcard, IRP* irp) status = smartcard_pack_status_return(smartcard, irp->output, &ret); - if (status) + if (status != SCARD_S_SUCCESS) return status; if (mszReaderNames) - SCardFreeMemory(hContext, mszReaderNames); + SCardFreeMemory(operation->hContext, mszReaderNames); return ret.ReturnCode; } -static DWORD smartcard_StatusW(SMARTCARD_DEVICE* smartcard, IRP* irp) +static DWORD smartcard_StatusW_Decode(SMARTCARD_DEVICE* smartcard, SMARTCARD_OPERATION* operation, Status_Call* call) +{ + LONG status; + IRP* irp = operation->irp; + + if (!call) + return STATUS_NO_MEMORY; + + status = smartcard_unpack_status_call(smartcard, irp->input, call); + + smartcard_trace_status_call(smartcard, call, TRUE); + + operation->hContext = smartcard_scard_context_native_from_redir(smartcard, &(call->hContext)); + operation->hCard = smartcard_scard_handle_native_from_redir(smartcard, &(call->hCard)); + + return status; +} + +static DWORD smartcard_StatusW_Call(SMARTCARD_DEVICE* smartcard, SMARTCARD_OPERATION* operation, Status_Call* call) { LONG status; - SCARDHANDLE hCard; - SCARDCONTEXT hContext; - Status_Call call; Status_Return ret; DWORD cchReaderLen = 0; LPWSTR mszReaderNames = NULL; + IRP* irp = operation->irp; - status = smartcard_unpack_status_call(smartcard, irp->input, &call); + if (call->cbAtrLen > 32) + call->cbAtrLen = 32; - smartcard_trace_status_call(smartcard, &call, TRUE); - - if (status) - return status; - - hContext = smartcard_scard_context_native_from_redir(smartcard, &(call.hContext)); - hCard = smartcard_scard_handle_native_from_redir(smartcard, &(call.hCard)); - - if (call.cbAtrLen > 32) - call.cbAtrLen = 32; - - ret.cbAtrLen = call.cbAtrLen; + ret.cbAtrLen = call->cbAtrLen; ZeroMemory(ret.pbAtr, 32); cchReaderLen = SCARD_AUTOALLOCATE; - status = ret.ReturnCode = SCardStatusW(hCard, (LPWSTR) &mszReaderNames, &cchReaderLen, + status = ret.ReturnCode = SCardStatusW(operation->hCard, (LPWSTR) &mszReaderNames, &cchReaderLen, &ret.dwState, &ret.dwProtocol, (BYTE*) &ret.pbAtr, &ret.cbAtrLen); ret.mszReaderNames = (BYTE*) mszReaderNames; @@ -792,155 +888,170 @@ static DWORD smartcard_StatusW(SMARTCARD_DEVICE* smartcard, IRP* irp) status = smartcard_pack_status_return(smartcard, irp->output, &ret); - if (status) + if (status != SCARD_S_SUCCESS) return status; if (mszReaderNames) - SCardFreeMemory(hContext, mszReaderNames); + SCardFreeMemory(operation->hContext, mszReaderNames); return ret.ReturnCode; } -static UINT32 smartcard_Transmit(SMARTCARD_DEVICE* smartcard, IRP* irp) +static UINT32 smartcard_Transmit_Decode(SMARTCARD_DEVICE* smartcard, SMARTCARD_OPERATION* operation, Transmit_Call* call) +{ + LONG status; + IRP* irp = operation->irp; + + if (!call) + return STATUS_NO_MEMORY; + + status = smartcard_unpack_transmit_call(smartcard, irp->input, call); + + smartcard_trace_transmit_call(smartcard, call); + + operation->hContext = smartcard_scard_context_native_from_redir(smartcard, &(call->hContext)); + operation->hCard = smartcard_scard_handle_native_from_redir(smartcard, &(call->hCard)); + + return status; +} + +static UINT32 smartcard_Transmit_Call(SMARTCARD_DEVICE* smartcard, SMARTCARD_OPERATION* operation, Transmit_Call* call) { LONG status; - SCARDHANDLE hCard; - SCARDCONTEXT hContext; - Transmit_Call call; Transmit_Return ret; - - status = smartcard_unpack_transmit_call(smartcard, irp->input, &call); - - smartcard_trace_transmit_call(smartcard, &call); - - if (status) - return status; - - hContext = smartcard_scard_context_native_from_redir(smartcard, &(call.hContext)); - hCard = smartcard_scard_handle_native_from_redir(smartcard, &(call.hCard)); + IRP* irp = operation->irp; ret.cbRecvLength = 0; ret.pbRecvBuffer = NULL; - if (call.cbRecvLength && !call.fpbRecvBufferIsNULL) + if (call->cbRecvLength && !call->fpbRecvBufferIsNULL) { - if (call.cbRecvLength >= 66560) - call.cbRecvLength = 66560; + if (call->cbRecvLength >= 66560) + call->cbRecvLength = 66560; - ret.cbRecvLength = call.cbRecvLength; + ret.cbRecvLength = call->cbRecvLength; ret.pbRecvBuffer = (BYTE*) malloc(ret.cbRecvLength); } - ret.pioRecvPci = call.pioRecvPci; + ret.pioRecvPci = call->pioRecvPci; - status = ret.ReturnCode = SCardTransmit(hCard, call.pioSendPci, call.pbSendBuffer, - call.cbSendLength, ret.pioRecvPci, ret.pbRecvBuffer, &(ret.cbRecvLength)); - - if (status) - return status; + status = ret.ReturnCode = SCardTransmit(operation->hCard, call->pioSendPci, call->pbSendBuffer, + call->cbSendLength, ret.pioRecvPci, ret.pbRecvBuffer, &(ret.cbRecvLength)); smartcard_trace_transmit_return(smartcard, &ret); status = smartcard_pack_transmit_return(smartcard, irp->output, &ret); - if (status) + if (status != SCARD_S_SUCCESS) return status; - if (call.pbSendBuffer) - free(call.pbSendBuffer); + if (call->pbSendBuffer) + free(call->pbSendBuffer); if (ret.pbRecvBuffer) free(ret.pbRecvBuffer); - if (call.pioSendPci) - free(call.pioSendPci); - if (call.pioRecvPci) - free(call.pioRecvPci); + if (call->pioSendPci) + free(call->pioSendPci); + if (call->pioRecvPci) + free(call->pioRecvPci); return ret.ReturnCode; } -static UINT32 smartcard_Control(SMARTCARD_DEVICE* smartcard, IRP* irp) +static UINT32 smartcard_Control_Decode(SMARTCARD_DEVICE* smartcard, SMARTCARD_OPERATION* operation, Control_Call* call) +{ + LONG status; + IRP* irp = operation->irp; + + if (!call) + return STATUS_NO_MEMORY; + + status = smartcard_unpack_control_call(smartcard, irp->input, call); + + smartcard_trace_control_call(smartcard, call); + + operation->hContext = smartcard_scard_context_native_from_redir(smartcard, &(call->hContext)); + operation->hCard = smartcard_scard_handle_native_from_redir(smartcard, &(call->hCard)); + + return status; +} + +static UINT32 smartcard_Control_Call(SMARTCARD_DEVICE* smartcard, SMARTCARD_OPERATION* operation, Control_Call* call) { LONG status; - SCARDHANDLE hCard; - SCARDCONTEXT hContext; - Control_Call call; Control_Return ret; + IRP* irp = operation->irp; - status = smartcard_unpack_control_call(smartcard, irp->input, &call); + ret.cbOutBufferSize = call->cbOutBufferSize; + ret.pvOutBuffer = (BYTE*) malloc(call->cbOutBufferSize); - smartcard_trace_control_call(smartcard, &call); + if (!ret.pvOutBuffer) + return SCARD_E_NO_MEMORY; - if (status) - return status; - - hContext = smartcard_scard_context_native_from_redir(smartcard, &(call.hContext)); - hCard = smartcard_scard_handle_native_from_redir(smartcard, &(call.hCard)); - - ret.cbOutBufferSize = call.cbOutBufferSize; - ret.pvOutBuffer = (BYTE*) malloc(call.cbOutBufferSize); - - status = ret.ReturnCode = SCardControl(hCard, - call.dwControlCode, call.pvInBuffer, call.cbInBufferSize, - ret.pvOutBuffer, call.cbOutBufferSize, &ret.cbOutBufferSize); - - if (status) - return status; + status = ret.ReturnCode = SCardControl(operation->hCard, + call->dwControlCode, call->pvInBuffer, call->cbInBufferSize, + ret.pvOutBuffer, call->cbOutBufferSize, &ret.cbOutBufferSize); smartcard_trace_control_return(smartcard, &ret); status = smartcard_pack_control_return(smartcard, irp->output, &ret); - if (status) + if (status != SCARD_S_SUCCESS) return status; - if (call.pvInBuffer) - free(call.pvInBuffer); + if (call->pvInBuffer) + free(call->pvInBuffer); if (ret.pvOutBuffer) free(ret.pvOutBuffer); return ret.ReturnCode; } -static UINT32 smartcard_GetAttrib(SMARTCARD_DEVICE* smartcard, IRP* irp) +static UINT32 smartcard_GetAttrib_Decode(SMARTCARD_DEVICE* smartcard, SMARTCARD_OPERATION* operation, GetAttrib_Call* call) +{ + LONG status; + IRP* irp = operation->irp; + + if (!call) + return STATUS_NO_MEMORY; + + status = smartcard_unpack_get_attrib_call(smartcard, irp->input, call); + + smartcard_trace_get_attrib_call(smartcard, call); + + operation->hContext = smartcard_scard_context_native_from_redir(smartcard, &(call->hContext)); + operation->hCard = smartcard_scard_handle_native_from_redir(smartcard, &(call->hCard)); + + return status; +} + +static UINT32 smartcard_GetAttrib_Call(SMARTCARD_DEVICE* smartcard, SMARTCARD_OPERATION* operation, GetAttrib_Call* call) { LONG status; DWORD cbAttrLen; - SCARDHANDLE hCard; - SCARDCONTEXT hContext; - GetAttrib_Call call; GetAttrib_Return ret; - - status = smartcard_unpack_get_attrib_call(smartcard, irp->input, &call); - - smartcard_trace_get_attrib_call(smartcard, &call); - - if (status) - return status; - - hContext = smartcard_scard_context_native_from_redir(smartcard, &(call.hContext)); - hCard = smartcard_scard_handle_native_from_redir(smartcard, &(call.hCard)); + IRP* irp = operation->irp; ret.pbAttr = NULL; - if (call.fpbAttrIsNULL) - call.cbAttrLen = 0; + if (call->fpbAttrIsNULL) + call->cbAttrLen = 0; - if (call.cbAttrLen) - ret.pbAttr = (BYTE*) malloc(call.cbAttrLen); + if (call->cbAttrLen) + ret.pbAttr = (BYTE*) malloc(call->cbAttrLen); - cbAttrLen = call.cbAttrLen; + cbAttrLen = call->cbAttrLen; - status = ret.ReturnCode = SCardGetAttrib(hCard, call.dwAttrId, ret.pbAttr, &cbAttrLen); + status = ret.ReturnCode = SCardGetAttrib(operation->hCard, call->dwAttrId, ret.pbAttr, &cbAttrLen); ret.cbAttrLen = cbAttrLen; - smartcard_trace_get_attrib_return(smartcard, &ret, call.dwAttrId); + smartcard_trace_get_attrib_return(smartcard, &ret, call->dwAttrId); if (ret.ReturnCode) { WLog_Print(smartcard->log, WLOG_WARN, "SCardGetAttrib: %s (0x%08X) cbAttrLen: %d\n", - SCardGetAttributeString(call.dwAttrId), call.dwAttrId, call.cbAttrLen); + SCardGetAttributeString(call->dwAttrId), call->dwAttrId, call->cbAttrLen); Stream_Zero(irp->output, 256); return ret.ReturnCode; @@ -948,7 +1059,7 @@ static UINT32 smartcard_GetAttrib(SMARTCARD_DEVICE* smartcard, IRP* irp) status = smartcard_pack_get_attrib_return(smartcard, irp->output, &ret); - if (status) + if (status != SCARD_S_SUCCESS) return status; free(ret.pbAttr); @@ -956,10 +1067,12 @@ static UINT32 smartcard_GetAttrib(SMARTCARD_DEVICE* smartcard, IRP* irp) return ret.ReturnCode; } -static UINT32 smartcard_AccessStartedEvent(SMARTCARD_DEVICE* smartcard, IRP* irp) +static UINT32 smartcard_AccessStartedEvent_Decode(SMARTCARD_DEVICE* smartcard, SMARTCARD_OPERATION* operation, Long_Call* call) { - UINT32 status; - Long_Return ret; + IRP* irp = operation->irp; + + if (!call) + return STATUS_NO_MEMORY; if (Stream_GetRemainingLength(irp->input) < 4) { @@ -968,39 +1081,36 @@ static UINT32 smartcard_AccessStartedEvent(SMARTCARD_DEVICE* smartcard, IRP* irp return SCARD_F_INTERNAL_ERROR; } - Stream_Seek(irp->input, 4); /* Unused (4 bytes) */ - + Stream_Read_UINT32(irp->input, call->LongValue); /* Unused (4 bytes) */ + + return SCARD_S_SUCCESS; +} + +static UINT32 smartcard_AccessStartedEvent_Call(SMARTCARD_DEVICE* smartcard, SMARTCARD_OPERATION* operation, Long_Call* call) +{ + UINT32 status; + Long_Return ret; + status = ret.ReturnCode = SCARD_S_SUCCESS; + if (!smartcard->StartedEvent) + smartcard->StartedEvent = SCardAccessStartedEvent(); + + if (!smartcard->StartedEvent) + status = ret.ReturnCode = SCARD_E_NO_SERVICE; + return status; } -void smartcard_irp_device_control_peek_io_control_code(SMARTCARD_DEVICE* smartcard, IRP* irp, UINT32* ioControlCode) +UINT32 smartcard_irp_device_control_decode(SMARTCARD_DEVICE* smartcard, SMARTCARD_OPERATION* operation) { - *ioControlCode = 0; - - if (Stream_GetRemainingLength(irp->input) < 32) - { - WLog_Print(smartcard->log, WLOG_WARN, "Device Control Request is too short: %d", - (int) Stream_GetRemainingLength(irp->input)); - return; - } - - Stream_Seek_UINT32(irp->input); /* OutputBufferLength (4 bytes) */ - Stream_Seek_UINT32(irp->input); /* InputBufferLength (4 bytes) */ - Stream_Read_UINT32(irp->input, *ioControlCode); /* IoControlCode (4 bytes) */ - Stream_Rewind(irp->input, (4 + 4 + 4)); -} - -void smartcard_irp_device_control(SMARTCARD_DEVICE* smartcard, IRP* irp) -{ - UINT32 result; UINT32 status; UINT32 offset; + void* call = NULL; UINT32 ioControlCode; UINT32 outputBufferLength; UINT32 inputBufferLength; - UINT32 objectBufferLength; + IRP* irp = operation->irp; /* Device Control Request */ @@ -1008,7 +1118,7 @@ void smartcard_irp_device_control(SMARTCARD_DEVICE* smartcard, IRP* irp) { WLog_Print(smartcard->log, WLOG_WARN, "Device Control Request is too short: %d", (int) Stream_GetRemainingLength(irp->input)); - return; + return SCARD_F_INTERNAL_ERROR; } Stream_Read_UINT32(irp->input, outputBufferLength); /* OutputBufferLength (4 bytes) */ @@ -1016,12 +1126,14 @@ void smartcard_irp_device_control(SMARTCARD_DEVICE* smartcard, IRP* irp) Stream_Read_UINT32(irp->input, ioControlCode); /* IoControlCode (4 bytes) */ Stream_Seek(irp->input, 20); /* Padding (20 bytes) */ + operation->ioControlCode = ioControlCode; + if (Stream_Length(irp->input) != (Stream_GetPosition(irp->input) + inputBufferLength)) { WLog_Print(smartcard->log, WLOG_WARN, "InputBufferLength mismatch: Actual: %d Expected: %d\n", Stream_Length(irp->input), Stream_GetPosition(irp->input) + inputBufferLength); - return; + return SCARD_F_INTERNAL_ERROR; } WLog_Print(smartcard->log, WLOG_DEBUG, "%s (0x%08X) FileId: %d CompletionId: %d", @@ -1038,14 +1150,296 @@ void smartcard_irp_device_control(SMARTCARD_DEVICE* smartcard, IRP* irp) status = smartcard_unpack_common_type_header(smartcard, irp->input); if (status) - return; + return SCARD_F_INTERNAL_ERROR; status = smartcard_unpack_private_type_header(smartcard, irp->input); if (status) - return; + return SCARD_F_INTERNAL_ERROR; } + /* Decode */ + + switch (ioControlCode) + { + case SCARD_IOCTL_ESTABLISHCONTEXT: + call = calloc(1, sizeof(EstablishContext_Call)); + status = smartcard_EstablishContext_Decode(smartcard, operation, (EstablishContext_Call*) call); + break; + + case SCARD_IOCTL_RELEASECONTEXT: + call = calloc(1, sizeof(Context_Call)); + status = smartcard_ReleaseContext_Decode(smartcard, operation, (Context_Call*) call); + break; + + case SCARD_IOCTL_ISVALIDCONTEXT: + call = calloc(1, sizeof(Context_Call)); + status = smartcard_IsValidContext_Decode(smartcard, operation, (Context_Call*) call); + break; + + case SCARD_IOCTL_LISTREADERGROUPSA: + status = SCARD_F_INTERNAL_ERROR; + break; + + case SCARD_IOCTL_LISTREADERGROUPSW: + status = SCARD_F_INTERNAL_ERROR; + break; + + case SCARD_IOCTL_LISTREADERSA: + call = calloc(1, sizeof(ListReaders_Call)); + status = smartcard_ListReadersA_Decode(smartcard, operation, (ListReaders_Call*) call); + break; + + case SCARD_IOCTL_LISTREADERSW: + call = calloc(1, sizeof(ListReaders_Call)); + status = smartcard_ListReadersW_Decode(smartcard, operation, (ListReaders_Call*) call); + break; + + case SCARD_IOCTL_INTRODUCEREADERGROUPA: + status = SCARD_F_INTERNAL_ERROR; + break; + + case SCARD_IOCTL_INTRODUCEREADERGROUPW: + status = SCARD_F_INTERNAL_ERROR; + break; + + case SCARD_IOCTL_FORGETREADERGROUPA: + status = SCARD_F_INTERNAL_ERROR; + break; + + case SCARD_IOCTL_FORGETREADERGROUPW: + status = SCARD_F_INTERNAL_ERROR; + break; + + case SCARD_IOCTL_INTRODUCEREADERA: + status = SCARD_F_INTERNAL_ERROR; + break; + + case SCARD_IOCTL_INTRODUCEREADERW: + status = SCARD_F_INTERNAL_ERROR; + break; + + case SCARD_IOCTL_FORGETREADERA: + status = SCARD_F_INTERNAL_ERROR; + break; + + case SCARD_IOCTL_FORGETREADERW: + status = SCARD_F_INTERNAL_ERROR; + break; + + case SCARD_IOCTL_ADDREADERTOGROUPA: + status = SCARD_F_INTERNAL_ERROR; + break; + + case SCARD_IOCTL_ADDREADERTOGROUPW: + status = SCARD_F_INTERNAL_ERROR; + break; + + case SCARD_IOCTL_REMOVEREADERFROMGROUPA: + status = SCARD_F_INTERNAL_ERROR; + break; + + case SCARD_IOCTL_REMOVEREADERFROMGROUPW: + status = SCARD_F_INTERNAL_ERROR; + break; + + case SCARD_IOCTL_LOCATECARDSA: + status = SCARD_F_INTERNAL_ERROR; + break; + + case SCARD_IOCTL_LOCATECARDSW: + status = SCARD_F_INTERNAL_ERROR; + break; + + case SCARD_IOCTL_GETSTATUSCHANGEA: + call = calloc(1, sizeof(GetStatusChangeA_Call)); + status = smartcard_GetStatusChangeA_Decode(smartcard, operation, (GetStatusChangeA_Call*) call); + break; + + case SCARD_IOCTL_GETSTATUSCHANGEW: + call = calloc(1, sizeof(GetStatusChangeW_Call)); + status = smartcard_GetStatusChangeW_Decode(smartcard, operation, (GetStatusChangeW_Call*) call); + break; + + case SCARD_IOCTL_CANCEL: + call = calloc(1, sizeof(Context_Call)); + status = smartcard_Cancel_Decode(smartcard, operation, (Context_Call*) call); + break; + + case SCARD_IOCTL_CONNECTA: + call = calloc(1, sizeof(ConnectA_Call)); + status = smartcard_ConnectA_Decode(smartcard, operation, (ConnectA_Call*) call); + break; + + case SCARD_IOCTL_CONNECTW: + call = calloc(1, sizeof(ConnectW_Call)); + status = smartcard_ConnectW_Decode(smartcard, operation, (ConnectW_Call*) call); + break; + + case SCARD_IOCTL_RECONNECT: + call = calloc(1, sizeof(Reconnect_Call)); + status = smartcard_Reconnect_Decode(smartcard, operation, (Reconnect_Call*) call); + break; + + case SCARD_IOCTL_DISCONNECT: + call = calloc(1, sizeof(HCardAndDisposition_Call)); + status = smartcard_Disconnect_Decode(smartcard, operation, (HCardAndDisposition_Call*) call); + break; + + case SCARD_IOCTL_BEGINTRANSACTION: + call = calloc(1, sizeof(HCardAndDisposition_Call)); + status = smartcard_BeginTransaction_Decode(smartcard, operation, (HCardAndDisposition_Call*) call); + break; + + case SCARD_IOCTL_ENDTRANSACTION: + call = calloc(1, sizeof(HCardAndDisposition_Call)); + status = smartcard_EndTransaction_Decode(smartcard, operation, (HCardAndDisposition_Call*) call); + break; + + case SCARD_IOCTL_STATE: + call = calloc(1, sizeof(State_Call)); + status = smartcard_State_Decode(smartcard, operation, (State_Call*) call); + break; + + case SCARD_IOCTL_STATUSA: + call = calloc(1, sizeof(Status_Call)); + status = smartcard_StatusA_Decode(smartcard, operation, (Status_Call*) call); + break; + + case SCARD_IOCTL_STATUSW: + call = calloc(1, sizeof(Status_Call)); + status = smartcard_StatusW_Decode(smartcard, operation, (Status_Call*) call); + break; + + case SCARD_IOCTL_TRANSMIT: + call = calloc(1, sizeof(Transmit_Call)); + status = smartcard_Transmit_Decode(smartcard, operation, (Transmit_Call*) call); + break; + + case SCARD_IOCTL_CONTROL: + call = calloc(1, sizeof(Control_Call)); + status = smartcard_Control_Decode(smartcard, operation, (Control_Call*) call); + break; + + case SCARD_IOCTL_GETATTRIB: + call = calloc(1, sizeof(GetAttrib_Call)); + status = smartcard_GetAttrib_Decode(smartcard, operation, (GetAttrib_Call*) call); + break; + + case SCARD_IOCTL_SETATTRIB: + status = SCARD_F_INTERNAL_ERROR; + break; + + case SCARD_IOCTL_ACCESSSTARTEDEVENT: + call = calloc(1, sizeof(Long_Call)); + status = smartcard_AccessStartedEvent_Decode(smartcard, operation, (Long_Call*) call); + break; + + case SCARD_IOCTL_LOCATECARDSBYATRA: + status = SCARD_F_INTERNAL_ERROR; + break; + + case SCARD_IOCTL_LOCATECARDSBYATRW: + status = SCARD_F_INTERNAL_ERROR; + break; + + case SCARD_IOCTL_READCACHEA: + status = SCARD_F_INTERNAL_ERROR; + break; + + case SCARD_IOCTL_READCACHEW: + status = SCARD_F_INTERNAL_ERROR; + break; + + case SCARD_IOCTL_WRITECACHEA: + status = SCARD_F_INTERNAL_ERROR; + break; + + case SCARD_IOCTL_WRITECACHEW: + status = SCARD_F_INTERNAL_ERROR; + break; + + case SCARD_IOCTL_GETTRANSMITCOUNT: + status = SCARD_F_INTERNAL_ERROR; + break; + + case SCARD_IOCTL_RELEASESTARTEDEVENT: + status = SCARD_F_INTERNAL_ERROR; + break; + + case SCARD_IOCTL_GETREADERICON: + status = SCARD_F_INTERNAL_ERROR; + break; + + case SCARD_IOCTL_GETDEVICETYPEID: + status = SCARD_F_INTERNAL_ERROR; + break; + + default: + status = SCARD_F_INTERNAL_ERROR; + break; + } + + if ((ioControlCode != SCARD_IOCTL_ACCESSSTARTEDEVENT) && + (ioControlCode != SCARD_IOCTL_RELEASESTARTEDEVENT)) + { + offset = (RDPDR_DEVICE_IO_REQUEST_LENGTH + RDPDR_DEVICE_IO_CONTROL_REQ_HDR_LENGTH); + + smartcard_unpack_read_size_align(smartcard, irp->input, + Stream_GetPosition(irp->input) - offset, 8); + } + + if (((size_t) Stream_GetPosition(irp->input)) < Stream_Length(irp->input)) + { + UINT32 difference; + + difference = (int) (Stream_Length(irp->input) - Stream_GetPosition(irp->input)); + + WLog_Print(smartcard->log, WLOG_WARN, + "IRP was not fully parsed %s (0x%08X): Actual: %d, Expected: %d, Difference: %d", + smartcard_get_ioctl_string(ioControlCode, TRUE), ioControlCode, + (int) Stream_GetPosition(irp->input), (int) Stream_Length(irp->input), difference); + + winpr_HexDump(Stream_Pointer(irp->input), difference); + } + + if (((size_t) Stream_GetPosition(irp->input)) > Stream_Length(irp->input)) + { + UINT32 difference; + + difference = (int) (Stream_GetPosition(irp->input) - Stream_Length(irp->input)); + + WLog_Print(smartcard->log, WLOG_WARN, + "IRP was parsed beyond its end %s (0x%08X): Actual: %d, Expected: %d, Difference: %d", + smartcard_get_ioctl_string(ioControlCode, TRUE), ioControlCode, + (int) Stream_GetPosition(irp->input), (int) Stream_Length(irp->input), difference); + } + + if (status != SCARD_S_SUCCESS) + { + free(call); + call = NULL; + } + + operation->call = call; + + return status; +} + +UINT32 smartcard_irp_device_control_call(SMARTCARD_DEVICE* smartcard, SMARTCARD_OPERATION* operation) +{ + IRP* irp; + UINT32 result; + UINT32 offset; + ULONG_PTR* call; + UINT32 ioControlCode; + UINT32 outputBufferLength; + UINT32 objectBufferLength; + + irp = operation->irp; + call = operation->call; + ioControlCode = operation->ioControlCode; + /** * [MS-RDPESC] 3.2.5.1: Sending Outgoing Messages: * the output buffer length SHOULD be set to 2048 @@ -1063,18 +1457,20 @@ void smartcard_irp_device_control(SMARTCARD_DEVICE* smartcard, IRP* irp) Stream_Seek_UINT32(irp->output); /* Result (4 bytes) */ + /* Call */ + switch (ioControlCode) { case SCARD_IOCTL_ESTABLISHCONTEXT: - result = smartcard_EstablishContext(smartcard, irp); + result = smartcard_EstablishContext_Call(smartcard, operation, (EstablishContext_Call*) call); break; case SCARD_IOCTL_RELEASECONTEXT: - result = smartcard_ReleaseContext(smartcard, irp); + result = smartcard_ReleaseContext_Call(smartcard, operation, (Context_Call*) call); break; case SCARD_IOCTL_ISVALIDCONTEXT: - result = smartcard_IsValidContext(smartcard, irp); + result = smartcard_IsValidContext_Call(smartcard, operation, (Context_Call*) call); break; case SCARD_IOCTL_LISTREADERGROUPSA: @@ -1086,11 +1482,11 @@ void smartcard_irp_device_control(SMARTCARD_DEVICE* smartcard, IRP* irp) break; case SCARD_IOCTL_LISTREADERSA: - result = smartcard_ListReadersA(smartcard, irp); + result = smartcard_ListReadersA_Call(smartcard, operation, (ListReaders_Call*) call); break; case SCARD_IOCTL_LISTREADERSW: - result = smartcard_ListReadersW(smartcard, irp); + result = smartcard_ListReadersW_Call(smartcard, operation, (ListReaders_Call*) call); break; case SCARD_IOCTL_INTRODUCEREADERGROUPA: @@ -1150,63 +1546,63 @@ void smartcard_irp_device_control(SMARTCARD_DEVICE* smartcard, IRP* irp) break; case SCARD_IOCTL_GETSTATUSCHANGEA: - result = smartcard_GetStatusChangeA(smartcard, irp); + result = smartcard_GetStatusChangeA_Call(smartcard, operation, (GetStatusChangeA_Call*) call); break; case SCARD_IOCTL_GETSTATUSCHANGEW: - result = smartcard_GetStatusChangeW(smartcard, irp); + result = smartcard_GetStatusChangeW_Call(smartcard, operation, (GetStatusChangeW_Call*) call); break; case SCARD_IOCTL_CANCEL: - result = smartcard_Cancel(smartcard, irp); + result = smartcard_Cancel_Call(smartcard, operation, (Context_Call*) call); break; case SCARD_IOCTL_CONNECTA: - result = smartcard_ConnectA(smartcard, irp); + result = smartcard_ConnectA_Call(smartcard, operation, (ConnectA_Call*) call); break; case SCARD_IOCTL_CONNECTW: - result = smartcard_ConnectW(smartcard, irp); + result = smartcard_ConnectW_Call(smartcard, operation, (ConnectW_Call*) call); break; case SCARD_IOCTL_RECONNECT: - result = smartcard_Reconnect(smartcard, irp); + result = smartcard_Reconnect_Call(smartcard, operation, (Reconnect_Call*) call); break; case SCARD_IOCTL_DISCONNECT: - result = smartcard_Disconnect(smartcard, irp); + result = smartcard_Disconnect_Call(smartcard, operation, (HCardAndDisposition_Call*) call); break; case SCARD_IOCTL_BEGINTRANSACTION: - result = smartcard_BeginTransaction(smartcard, irp); + result = smartcard_BeginTransaction_Call(smartcard, operation, (HCardAndDisposition_Call*) call); break; case SCARD_IOCTL_ENDTRANSACTION: - result = smartcard_EndTransaction(smartcard, irp); + result = smartcard_EndTransaction_Call(smartcard, operation, (HCardAndDisposition_Call*) call); break; case SCARD_IOCTL_STATE: - result = smartcard_State(smartcard, irp); + result = smartcard_State_Call(smartcard, operation, (State_Call*) call); break; case SCARD_IOCTL_STATUSA: - result = smartcard_StatusA(smartcard, irp); + result = smartcard_StatusA_Call(smartcard, operation, (Status_Call*) call); break; case SCARD_IOCTL_STATUSW: - result = smartcard_StatusW(smartcard, irp); + result = smartcard_StatusW_Call(smartcard, operation, (Status_Call*) call); break; case SCARD_IOCTL_TRANSMIT: - result = smartcard_Transmit(smartcard, irp); + result = smartcard_Transmit_Call(smartcard, operation, (Transmit_Call*) call); break; case SCARD_IOCTL_CONTROL: - result = smartcard_Control(smartcard, irp); + result = smartcard_Control_Call(smartcard, operation, (Control_Call*) call); break; case SCARD_IOCTL_GETATTRIB: - result = smartcard_GetAttrib(smartcard, irp); + result = smartcard_GetAttrib_Call(smartcard, operation, (GetAttrib_Call*) call); break; case SCARD_IOCTL_SETATTRIB: @@ -1214,7 +1610,7 @@ void smartcard_irp_device_control(SMARTCARD_DEVICE* smartcard, IRP* irp) break; case SCARD_IOCTL_ACCESSSTARTEDEVENT: - result = smartcard_AccessStartedEvent(smartcard, irp); + result = smartcard_AccessStartedEvent_Call(smartcard, operation, (Long_Call*) call); break; case SCARD_IOCTL_LOCATECARDSBYATRA: @@ -1262,13 +1658,21 @@ void smartcard_irp_device_control(SMARTCARD_DEVICE* smartcard, IRP* irp) break; } + free(call); + + /** + * [MS-RPCE] 2.2.6.3 Primitive Type Serialization + * The type MUST be aligned on an 8-byte boundary. If the size of the + * primitive type is not a multiple of 8 bytes, the data MUST be padded. + */ + if ((ioControlCode != SCARD_IOCTL_ACCESSSTARTEDEVENT) && (ioControlCode != SCARD_IOCTL_RELEASESTARTEDEVENT)) { - offset = (RDPDR_DEVICE_IO_REQUEST_LENGTH + RDPDR_DEVICE_IO_CONTROL_REQ_HDR_LENGTH); + offset = (RDPDR_DEVICE_IO_RESPONSE_LENGTH + RDPDR_DEVICE_IO_CONTROL_RSP_HDR_LENGTH); - smartcard_unpack_read_size_align(smartcard, irp->input, - Stream_GetPosition(irp->input) - offset, 8); + smartcard_pack_write_size_align(smartcard, irp->output, + Stream_GetPosition(irp->output) - offset, 8); } if ((result != SCARD_S_SUCCESS) && (result != SCARD_E_TIMEOUT) && @@ -1294,47 +1698,6 @@ void smartcard_irp_device_control(SMARTCARD_DEVICE* smartcard, IRP* irp) smartcard_get_ioctl_string(ioControlCode, TRUE), ioControlCode, result); } - if (((size_t) Stream_GetPosition(irp->input)) < Stream_Length(irp->input)) - { - UINT32 difference; - - difference = (int) (Stream_Length(irp->input) - Stream_GetPosition(irp->input)); - - WLog_Print(smartcard->log, WLOG_WARN, - "IRP was not fully parsed %s (0x%08X): Actual: %d, Expected: %d, Difference: %d", - smartcard_get_ioctl_string(ioControlCode, TRUE), ioControlCode, - (int) Stream_GetPosition(irp->input), (int) Stream_Length(irp->input), difference); - - winpr_HexDump(Stream_Pointer(irp->input), difference); - } - - if (((size_t) Stream_GetPosition(irp->input)) > Stream_Length(irp->input)) - { - UINT32 difference; - - difference = (int) (Stream_GetPosition(irp->input) - Stream_Length(irp->input)); - - WLog_Print(smartcard->log, WLOG_WARN, - "IRP was parsed beyond its end %s (0x%08X): Actual: %d, Expected: %d, Difference: %d", - smartcard_get_ioctl_string(ioControlCode, TRUE), ioControlCode, - (int) Stream_GetPosition(irp->input), (int) Stream_Length(irp->input), difference); - } - - /** - * [MS-RPCE] 2.2.6.3 Primitive Type Serialization - * The type MUST be aligned on an 8-byte boundary. If the size of the - * primitive type is not a multiple of 8 bytes, the data MUST be padded. - */ - - if ((ioControlCode != SCARD_IOCTL_ACCESSSTARTEDEVENT) && - (ioControlCode != SCARD_IOCTL_RELEASESTARTEDEVENT)) - { - offset = (RDPDR_DEVICE_IO_RESPONSE_LENGTH + RDPDR_DEVICE_IO_CONTROL_RSP_HDR_LENGTH); - - smartcard_pack_write_size_align(smartcard, irp->output, - Stream_GetPosition(irp->output) - offset, 8); - } - Stream_SealLength(irp->output); outputBufferLength = Stream_Length(irp->output) - RDPDR_DEVICE_IO_RESPONSE_LENGTH - 4; @@ -1350,4 +1713,7 @@ void smartcard_irp_device_control(SMARTCARD_DEVICE* smartcard, IRP* irp) Stream_Write_UINT32(irp->output, result); /* Result (4 bytes) */ Stream_SetPosition(irp->output, Stream_Length(irp->output)); + + return SCARD_S_SUCCESS; } + diff --git a/channels/smartcard/client/smartcard_pack.h b/channels/smartcard/client/smartcard_pack.h index fdf17087c..788c75795 100644 --- a/channels/smartcard/client/smartcard_pack.h +++ b/channels/smartcard/client/smartcard_pack.h @@ -24,8 +24,6 @@ #include #include -#include "smartcard_main.h" - /* interface type_scard_pack */ /* [unique][version][uuid] */ @@ -41,10 +39,15 @@ typedef struct _REDIR_SCARDHANDLE /* [size_is] */ BYTE pbHandle[8]; } REDIR_SCARDHANDLE; +typedef struct _Long_Call +{ + LONG LongValue; +} Long_Call; + typedef struct _Long_Return { LONG ReturnCode; -} Long_Return; +} Long_Return; typedef struct _longAndMultiString_Return { @@ -429,6 +432,8 @@ typedef struct _WriteCacheW_Call #define SMARTCARD_COMMON_TYPE_HEADER_LENGTH 8 #define SMARTCARD_PRIVATE_TYPE_HEADER_LENGTH 8 +#include "smartcard_main.h" + UINT32 smartcard_pack_write_size_align(SMARTCARD_DEVICE* smartcard, wStream* s, UINT32 size, UINT32 alignment); UINT32 smartcard_unpack_read_size_align(SMARTCARD_DEVICE* smartcard, wStream* s, UINT32 size, UINT32 alignment); diff --git a/client/CMakeLists.txt b/client/CMakeLists.txt index 0c4c35aa2..c171c0dfe 100644 --- a/client/CMakeLists.txt +++ b/client/CMakeLists.txt @@ -19,36 +19,38 @@ add_subdirectory(common) -if(WIN32) - add_subdirectory(Windows) -else() - if(WITH_SAMPLE) - add_subdirectory(Sample) - endif() - - if(WITH_DIRECTFB) - add_subdirectory(DirectFB) - endif() -endif() - -if(WITH_X11) - add_subdirectory(X11) -endif() - -if(APPLE) - if(IOS) - if(IS_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/iOS") - message(STATUS "Adding iOS client") - add_subdirectory(iOS) - endif() +if(FREERDP_VENDOR) + if(WIN32) + add_subdirectory(Windows) else() - add_subdirectory(Mac) - endif() -endif() + if(WITH_SAMPLE) + add_subdirectory(Sample) + endif() -if(ANDROID) - message(STATUS "Adding Android client") - add_subdirectory(Android) + if(WITH_DIRECTFB) + add_subdirectory(DirectFB) + endif() + endif() + + if(WITH_X11) + add_subdirectory(X11) + endif() + + if(APPLE) + if(IOS) + if(IS_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/iOS") + message(STATUS "Adding iOS client") + add_subdirectory(iOS) + endif() + else() + add_subdirectory(Mac) + endif() + endif() + + if(ANDROID) + message(STATUS "Adding Android client") + add_subdirectory(Android) + endif() endif() # Pick up other clients diff --git a/client/Windows/wf_cliprdr.c b/client/Windows/wf_cliprdr.c index 5a5348670..d319c83de 100644 --- a/client/Windows/wf_cliprdr.c +++ b/client/Windows/wf_cliprdr.c @@ -165,7 +165,7 @@ static void cliprdr_send_format_list(cliprdrContext *cliprdr) format_count = CountClipboardFormats(); data_size = format_count * (4 + MAX_PATH * 2); - format_data = (BYTE *)calloc(1, data_size); + format_data = (BYTE*) calloc(1, data_size); assert(format_data != NULL); while (format = EnumClipboardFormats(format)) @@ -191,7 +191,7 @@ static void cliprdr_send_format_list(cliprdrContext *cliprdr) { if (format >= CF_MAX) { - static wchar_t wName[MAX_PATH] = {0}; + static WCHAR wName[MAX_PATH] = {0}; int wLen; ZeroMemory(wName, MAX_PATH*2); @@ -216,15 +216,16 @@ static void cliprdr_send_format_list(cliprdrContext *cliprdr) if (stream_file_transferring) { - cliprdr_event->raw_format_data = (BYTE *)calloc(1, (4 + 42)); + cliprdr_event->raw_format_data_size = 4 + 42; + cliprdr_event->raw_format_data = (BYTE*) calloc(1, cliprdr_event->raw_format_data_size); format = RegisterClipboardFormatW(L"FileGroupDescriptorW"); Write_UINT32(cliprdr_event->raw_format_data, format); - wcscpy((wchar_t *)(cliprdr_event->raw_format_data + 4), L"FileGroupDescriptorW"); - cliprdr_event->raw_format_data_size = 4 + 42; + wcscpy_s((WCHAR*)(cliprdr_event->raw_format_data + 4), + (cliprdr_event->raw_format_data_size - 4) / 2, L"FileGroupDescriptorW"); } else { - cliprdr_event->raw_format_data = (BYTE *)calloc(1, len); + cliprdr_event->raw_format_data = (BYTE*) calloc(1, len); assert(cliprdr_event->raw_format_data != NULL); CopyMemory(cliprdr_event->raw_format_data, format_data, len); @@ -232,7 +233,7 @@ static void cliprdr_send_format_list(cliprdrContext *cliprdr) } free(format_data); - freerdp_channels_send_event(cliprdr->channels, (wMessage *) cliprdr_event); + freerdp_channels_send_event(cliprdr->channels, (wMessage*) cliprdr_event); } int cliprdr_send_data_request(cliprdrContext *cliprdr, UINT32 format) @@ -681,18 +682,21 @@ static BOOL wf_cliprdr_get_file_contents(wchar_t *file_name, BYTE *buffer, int p } /* path_name has a '\' at the end. e.g. c:\newfolder\, file_name is c:\newfolder\new.txt */ -static FILEDESCRIPTORW *wf_cliprdr_get_file_descriptor(wchar_t *file_name, int pathLen) +static FILEDESCRIPTORW *wf_cliprdr_get_file_descriptor(WCHAR* file_name, int pathLen) { FILEDESCRIPTORW *fd; HANDLE hFile; - fd = (FILEDESCRIPTORW *)malloc(sizeof(FILEDESCRIPTORW)); + fd = (FILEDESCRIPTORW*) malloc(sizeof(FILEDESCRIPTORW)); + if (!fd) return NULL; ZeroMemory(fd, sizeof(FILEDESCRIPTORW)); - hFile = CreateFileW(file_name, GENERIC_READ, FILE_SHARE_READ, NULL, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL | FILE_FLAG_BACKUP_SEMANTICS, NULL); + hFile = CreateFileW(file_name, GENERIC_READ, FILE_SHARE_READ, + NULL, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL | FILE_FLAG_BACKUP_SEMANTICS, NULL); + if (hFile == INVALID_HANDLE_VALUE) { free(fd); @@ -702,13 +706,15 @@ static FILEDESCRIPTORW *wf_cliprdr_get_file_descriptor(wchar_t *file_name, int p fd->dwFlags = FD_ATTRIBUTES | FD_FILESIZE | FD_WRITESTIME | FD_PROGRESSUI; fd->dwFileAttributes = GetFileAttributes(file_name); + if (!GetFileTime(hFile, NULL, NULL, &fd->ftLastWriteTime)) { fd->dwFlags &= ~FD_WRITESTIME; } + fd->nFileSizeLow = GetFileSize(hFile, &fd->nFileSizeHigh); - wcscpy(fd->cFileName, file_name + pathLen); + wcscpy_s(fd->cFileName, sizeof(fd->cFileName) / 2, file_name + pathLen); CloseHandle(hFile); return fd; @@ -727,8 +733,8 @@ static void wf_cliprdr_array_ensure_capacity(cliprdrContext *cliprdr) static void wf_cliprdr_add_to_file_arrays(cliprdrContext *cliprdr, WCHAR *full_file_name, int pathLen) { /* add to name array */ - cliprdr->file_names[cliprdr->nFiles] = (LPWSTR)malloc(MAX_PATH); - wcscpy(cliprdr->file_names[cliprdr->nFiles], full_file_name); + cliprdr->file_names[cliprdr->nFiles] = (LPWSTR) malloc(MAX_PATH); + wcscpy_s(cliprdr->file_names[cliprdr->nFiles], MAX_PATH, full_file_name); /* add to descriptor array */ cliprdr->fileDescriptor[cliprdr->nFiles] = wf_cliprdr_get_file_descriptor(full_file_name, pathLen); @@ -843,14 +849,15 @@ static void wf_cliprdr_process_cb_data_request_event(wfContext* wfc, RDP_CB_DATA format_etc.ptd = 0; result = IDataObject_GetData(dataObj, &format_etc, &stg_medium); + if (SUCCEEDED(result)) { DEBUG_CLIPRDR("Got FileGroupDescriptorW."); globlemem = (char *)GlobalLock(stg_medium.hGlobal); uSize = GlobalSize(stg_medium.hGlobal); size = uSize; - buff = malloc(uSize); - memcpy(buff, globlemem, uSize); + buff = (char*) malloc(uSize); + CopyMemory(buff, globlemem, uSize); GlobalUnlock(stg_medium.hGlobal); ReleaseStgMedium(&stg_medium); diff --git a/client/Windows/wf_cliprdr_EnumFORMATETC.c b/client/Windows/wf_cliprdr_EnumFORMATETC.c index 48a3bfbf7..d988cf560 100644 --- a/client/Windows/wf_cliprdr_EnumFORMATETC.c +++ b/client/Windows/wf_cliprdr_EnumFORMATETC.c @@ -94,7 +94,7 @@ HRESULT STDMETHODCALLTYPE CliprdrEnumFORMATETC_Skip(IEnumFORMATETC *This, ULONG { CliprdrEnumFORMATETC *instance = (CliprdrEnumFORMATETC *)This; - if (instance->m_nIndex + celt > instance->m_nNumFormats) + if (instance->m_nIndex + (LONG) celt > instance->m_nNumFormats) return S_FALSE; instance->m_nIndex += celt; diff --git a/client/common/cmdline.c b/client/common/cmdline.c index c201e4200..5e633c4e8 100644 --- a/client/common/cmdline.c +++ b/client/common/cmdline.c @@ -1086,19 +1086,19 @@ int freerdp_client_settings_command_line_status_print(rdpSettings* settings, int layouts = freerdp_keyboard_get_layouts(RDP_KEYBOARD_LAYOUT_TYPE_STANDARD); printf("\nKeyboard Layouts\n"); for (i = 0; layouts[i].code; i++) - printf("0x%08X\t%s\n", layouts[i].code, layouts[i].name); + printf("0x%08X\t%s\n", (int) layouts[i].code, layouts[i].name); free(layouts); layouts = freerdp_keyboard_get_layouts(RDP_KEYBOARD_LAYOUT_TYPE_VARIANT); printf("\nKeyboard Layout Variants\n"); for (i = 0; layouts[i].code; i++) - printf("0x%08X\t%s\n", layouts[i].code, layouts[i].name); + printf("0x%08X\t%s\n", (int) layouts[i].code, layouts[i].name); free(layouts); layouts = freerdp_keyboard_get_layouts(RDP_KEYBOARD_LAYOUT_TYPE_IME); printf("\nKeyboard Input Method Editors (IMEs)\n"); for (i = 0; layouts[i].code; i++) - printf("0x%08X\t%s\n", layouts[i].code, layouts[i].name); + printf("0x%08X\t%s\n", (int) layouts[i].code, layouts[i].name); free(layouts); printf("\n"); diff --git a/client/common/file.c b/client/common/file.c index 65861005b..a5e06503b 100644 --- a/client/common/file.c +++ b/client/common/file.c @@ -779,18 +779,18 @@ BOOL freerdp_client_populate_settings_from_rdp_file(rdpFile* file, rdpSettings* if (~file->SessionBpp) freerdp_set_param_uint32(settings, FreeRDP_ColorDepth, file->SessionBpp); if (~file->ConnectToConsole) - freerdp_set_param_uint32(settings, FreeRDP_ConsoleSession, file->ConnectToConsole); + freerdp_set_param_bool(settings, FreeRDP_ConsoleSession, file->ConnectToConsole); if (~file->AdministrativeSession) - freerdp_set_param_uint32(settings, FreeRDP_ConsoleSession, file->AdministrativeSession); + freerdp_set_param_bool(settings, FreeRDP_ConsoleSession, file->AdministrativeSession); if (~file->NegotiateSecurityLayer) - freerdp_set_param_uint32(settings, FreeRDP_NegotiateSecurityLayer, file->NegotiateSecurityLayer); + freerdp_set_param_bool(settings, FreeRDP_NegotiateSecurityLayer, file->NegotiateSecurityLayer); if (~file->EnableCredSSPSupport) - freerdp_set_param_uint32(settings, FreeRDP_NlaSecurity, file->EnableCredSSPSupport); + freerdp_set_param_bool(settings, FreeRDP_NlaSecurity, file->EnableCredSSPSupport); if (~((size_t) file->AlternateShell)) freerdp_set_param_string(settings, FreeRDP_AlternateShell, file->AlternateShell); if (~((size_t) file->ShellWorkingDirectory)) freerdp_set_param_string(settings, FreeRDP_ShellWorkingDirectory, file->ShellWorkingDirectory); - + if (~file->ScreenModeId) { /** @@ -810,6 +810,12 @@ BOOL freerdp_client_populate_settings_from_rdp_file(rdpFile* file, rdpSettings* (file->ScreenModeId == 1) ? TRUE : FALSE); } + if (~((size_t) file->SmartSizing)) + { + freerdp_set_param_bool(settings, FreeRDP_SmartSizing, + (file->SmartSizing == 1) ? TRUE : FALSE); + } + if (~((size_t) file->LoadBalanceInfo)) { settings->LoadBalanceInfo = (BYTE*) _strdup(file->LoadBalanceInfo); @@ -864,7 +870,7 @@ BOOL freerdp_client_populate_settings_from_rdp_file(rdpFile* file, rdpSettings* freerdp_set_param_string(settings, FreeRDP_GatewayHostname, file->GatewayHostname); if (~file->GatewayUsageMethod) - freerdp_set_gateway_usage_method(settings, settings->GatewayUsageMethod); + freerdp_set_gateway_usage_method(settings, file->GatewayUsageMethod); if (~file->PromptCredentialOnce) freerdp_set_param_bool(settings, FreeRDP_GatewayUseSameCredentials, file->PromptCredentialOnce); diff --git a/include/freerdp/crypto/tls.h b/include/freerdp/crypto/tls.h index bf5521300..180007e5e 100644 --- a/include/freerdp/crypto/tls.h +++ b/include/freerdp/crypto/tls.h @@ -70,7 +70,6 @@ struct rdp_tls SSL* ssl; BIO* bio; void* tsg; - int sockfd; SSL_CTX* ctx; BYTE* PublicKey; BIO_METHOD* methods; @@ -84,17 +83,11 @@ struct rdp_tls int alertDescription; }; -FREERDP_API int tls_connect(rdpTls* tls); -FREERDP_API BOOL tls_accept(rdpTls* tls, const char* cert_file, const char* privatekey_file); +FREERDP_API int tls_connect(rdpTls* tls, BIO *underlying); +FREERDP_API BOOL tls_accept(rdpTls* tls, BIO *underlying, const char* cert_file, const char* privatekey_file); FREERDP_API BOOL tls_disconnect(rdpTls* tls); -FREERDP_API int tls_read(rdpTls* tls, BYTE* data, int length); -FREERDP_API int tls_write(rdpTls* tls, BYTE* data, int length); - -FREERDP_API int tls_write_all(rdpTls* tls, BYTE* data, int length); - -FREERDP_API int tls_wait_read(rdpTls* tls); -FREERDP_API int tls_wait_write(rdpTls* tls); +FREERDP_API int tls_write_all(rdpTls* tls, const BYTE* data, int length); FREERDP_API int tls_set_alert_code(rdpTls* tls, int level, int description); diff --git a/include/freerdp/peer.h b/include/freerdp/peer.h index c89d37a07..4fbe75bfc 100644 --- a/include/freerdp/peer.h +++ b/include/freerdp/peer.h @@ -34,7 +34,10 @@ typedef void (*psPeerContextFree)(freerdp_peer* client, rdpContext* context); typedef BOOL (*psPeerInitialize)(freerdp_peer* client); typedef BOOL (*psPeerGetFileDescriptor)(freerdp_peer* client, void** rfds, int* rcount); typedef HANDLE (*psPeerGetEventHandle)(freerdp_peer* client); +typedef HANDLE (*psPeerGetReceiveEventHandle)(freerdp_peer* client); typedef BOOL (*psPeerCheckFileDescriptor)(freerdp_peer* client); +typedef BOOL (*psPeerIsWriteBlocked)(freerdp_peer* client); +typedef int (*psPeerDrainOutputBuffer)(freerdp_peer* client); typedef BOOL (*psPeerClose)(freerdp_peer* client); typedef void (*psPeerDisconnect)(freerdp_peer* client); typedef BOOL (*psPeerCapabilities)(freerdp_peer* client); @@ -62,6 +65,7 @@ struct rdp_freerdp_peer psPeerInitialize Initialize; psPeerGetFileDescriptor GetFileDescriptor; psPeerGetEventHandle GetEventHandle; + psPeerGetReceiveEventHandle GetReceiveEventHandle; psPeerCheckFileDescriptor CheckFileDescriptor; psPeerClose Close; psPeerDisconnect Disconnect; @@ -81,6 +85,9 @@ struct rdp_freerdp_peer BOOL activated; BOOL authenticated; SEC_WINNT_AUTH_IDENTITY identity; + + psPeerIsWriteBlocked IsWriteBlocked; + psPeerDrainOutputBuffer DrainOutputBuffer; }; #ifdef __cplusplus diff --git a/include/freerdp/settings.h b/include/freerdp/settings.h index 6e921eb21..d73ccf22b 100644 --- a/include/freerdp/settings.h +++ b/include/freerdp/settings.h @@ -597,6 +597,7 @@ typedef struct _RDPDR_PARALLEL RDPDR_PARALLEL; #define FreeRDP_RestrictedAdminModeRequired 1097 #define FreeRDP_AuthenticationServiceClass 1098 #define FreeRDP_DisableCredentialsDelegation 1099 +#define FreeRDP_AuthenticationLevel 1100 #define FreeRDP_MstscCookieMode 1152 #define FreeRDP_CookieMaxLength 1153 #define FreeRDP_PreconnectionId 1154 @@ -798,7 +799,8 @@ struct rdp_settings ALIGN64 char* Password; /* 22 */ ALIGN64 char* Domain; /* 23 */ ALIGN64 char* PasswordHash; /* 24 */ - UINT64 padding0064[64 - 25]; /* 25 */ + ALIGN64 BOOL WaitForOutputBufferFlush; /* 25 */ + UINT64 padding0064[64 - 26]; /* 26 */ UINT64 padding0128[128 - 64]; /* 64 */ /** @@ -952,7 +954,8 @@ struct rdp_settings ALIGN64 BOOL RestrictedAdminModeRequired; /* 1097 */ ALIGN64 char* AuthenticationServiceClass; /* 1098 */ ALIGN64 BOOL DisableCredentialsDelegation; /* 1099 */ - UINT64 padding1152[1152 - 1100]; /* 1100 */ + ALIGN64 BOOL AuthenticationLevel; /* 1100 */ + UINT64 padding1152[1152 - 1101]; /* 1101 */ /* Connection Cookie */ ALIGN64 BOOL MstscCookieMode; /* 1152 */ diff --git a/include/freerdp/utils/ringbuffer.h b/include/freerdp/utils/ringbuffer.h new file mode 100644 index 000000000..dfe9e8350 --- /dev/null +++ b/include/freerdp/utils/ringbuffer.h @@ -0,0 +1,128 @@ +/** + * FreeRDP: A Remote Desktop Protocol Implementation + * + * Copyright 2014 Thincast Technologies GmbH + * Copyright 2014 Hardening + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __RINGBUFFER_H___ +#define __RINGBUFFER_H___ + +#include +#include + + +/** @brief ring buffer meta data */ +struct _RingBuffer { + size_t initialSize; + size_t freeSize; + size_t size; + size_t readPtr; + size_t writePtr; + BYTE *buffer; +}; +typedef struct _RingBuffer RingBuffer; + + +/** @brief a piece of data in the ring buffer, exactly like a glibc iovec */ +struct _DataChunk { + size_t size; + const BYTE *data; +}; +typedef struct _DataChunk DataChunk; + +#ifdef __cplusplus +extern "C" { +#endif + +/** initialise a ringbuffer + * @param initialSize the initial capacity of the ringBuffer + * @return if the initialisation was successful + */ +FREERDP_API BOOL ringbuffer_init(RingBuffer *rb, size_t initialSize); + +/** destroys internal data used by this ringbuffer + * @param ringbuffer + */ +FREERDP_API void ringbuffer_destroy(RingBuffer *ringbuffer); + +/** computes the space used in this ringbuffer + * @param ringbuffer + * @return the number of bytes stored in that ringbuffer + */ +FREERDP_API size_t ringbuffer_used(const RingBuffer *ringbuffer); + +/** returns the capacity of the ring buffer + * @param ringbuffer + * @return the capacity of this ring buffer + */ +FREERDP_API size_t ringbuffer_capacity(const RingBuffer *ringbuffer); + +/** writes some bytes in the ringbuffer, if the data doesn't fit, the ringbuffer + * is resized automatically + * + * @param rb the ringbuffer + * @param ptr a pointer on the data to add + * @param sz the size of the data to add + * @return if the operation was successful, it could fail in case of OOM during realloc() + */ +FREERDP_API BOOL ringbuffer_write(RingBuffer *rb, const BYTE *ptr, size_t sz); + + +/** ensures that we have sz bytes available at the write head, and return a pointer + * on the write head + * + * @param rb the ring buffer + * @param sz the size to ensure + * @return a pointer on the write head, or NULL in case of OOM + */ +FREERDP_API BYTE *ringbuffer_ensure_linear_write(RingBuffer *rb, size_t sz); + +/** move ahead the write head in case some byte were written directly by using + * a pointer retrieved via ringbuffer_ensure_linear_write(). This function is + * used to commit the written bytes. The provided size should not exceed the + * size ensured by ringbuffer_ensure_linear_write() + * + * @param rb the ring buffer + * @param sz the number of bytes that have been written + * @return if the operation was successful, FALSE is sz is too big + */ +FREERDP_API BOOL ringbuffer_commit_written_bytes(RingBuffer *rb, size_t sz); + + +/** peeks the buffer chunks for sz bytes and returns how many chunks are filled. + * Note that the sum of the resulting chunks may be smaller than sz. + * + * @param rb the ringbuffer + * @param chunks an array of data chunks that will contain data / size of chunks + * @param sz the requested size + * @return the number of chunks used for reading sz bytes + */ +FREERDP_API int ringbuffer_peek(const RingBuffer *rb, DataChunk chunks[2], size_t sz); + +/** move ahead the read head in case some byte were read using ringbuffer_peek() + * This function is used to commit the bytes that were effectively consumed. + * + * @param rb the ring buffer + * @param sz the + */ +FREERDP_API void ringbuffer_commit_read_bytes(RingBuffer *rb, size_t sz); + + +#ifdef __cplusplus +} +#endif + +#endif /* __RINGBUFFER_H___ */ diff --git a/libfreerdp/codec/region.c b/libfreerdp/codec/region.c index 4ad0213f4..aee2e3036 100644 --- a/libfreerdp/codec/region.c +++ b/libfreerdp/codec/region.c @@ -1,24 +1,20 @@ /** - * Copyright © 2014 Thincast Technologies GmbH - * Copyright © 2014 Hardening + * FreeRDP: A Remote Desktop Protocol Implementation * - * Permission to use, copy, modify, distribute, and sell this software and - * its documentation for any purpose is hereby granted without fee, provided - * that the above copyright notice appear in all copies and that both that - * copyright notice and this permission notice appear in supporting - * documentation, and that the name of the copyright holders not be used in - * advertising or publicity pertaining to distribution of the software - * without specific, written prior permission. The copyright holders make - * no representations about the suitability of this software for any - * purpose. It is provided "as is" without express or implied warranty. + * Copyright 2014 Thincast Technologies GmbH + * Copyright 2014 Hardening * - * THE COPYRIGHT HOLDERS DISCLAIM ALL WARRANTIES WITH REGARD TO THIS - * SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS, IN NO EVENT SHALL THE COPYRIGHT HOLDERS BE LIABLE FOR ANY - * SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER - * RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF - * CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN - * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ #include diff --git a/libfreerdp/codec/test/TestFreeRDPRegion.c b/libfreerdp/codec/test/TestFreeRDPRegion.c index 80ff54729..1b39b11dd 100644 --- a/libfreerdp/codec/test/TestFreeRDPRegion.c +++ b/libfreerdp/codec/test/TestFreeRDPRegion.c @@ -1,24 +1,20 @@ /** - * Copyright © 2014 Thincast Technologies GmbH - * Copyright © 2014 Hardening + * FreeRDP: A Remote Desktop Protocol Implementation * - * Permission to use, copy, modify, distribute, and sell this software and - * its documentation for any purpose is hereby granted without fee, provided - * that the above copyright notice appear in all copies and that both that - * copyright notice and this permission notice appear in supporting - * documentation, and that the name of the copyright holders not be used in - * advertising or publicity pertaining to distribution of the software - * without specific, written prior permission. The copyright holders make - * no representations about the suitability of this software for any - * purpose. It is provided "as is" without express or implied warranty. + * Copyright 2014 Thincast Technologies GmbH + * Copyright 2014 Hardening * - * THE COPYRIGHT HOLDERS DISCLAIM ALL WARRANTIES WITH REGARD TO THIS - * SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS, IN NO EVENT SHALL THE COPYRIGHT HOLDERS BE LIABLE FOR ANY - * SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER - * RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF - * CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN - * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ #include diff --git a/libfreerdp/common/settings.c b/libfreerdp/common/settings.c index 6947895ad..4f28c8692 100644 --- a/libfreerdp/common/settings.c +++ b/libfreerdp/common/settings.c @@ -822,6 +822,10 @@ BOOL freerdp_get_param_bool(rdpSettings* settings, int id) return settings->DisableCredentialsDelegation; break; + case FreeRDP_AuthenticationLevel: + return settings->AuthenticationLevel; + break; + case FreeRDP_MstscCookieMode: return settings->MstscCookieMode; break; @@ -874,6 +878,10 @@ BOOL freerdp_get_param_bool(rdpSettings* settings, int id) return settings->AsyncChannels; break; + case FreeRDP_AsyncTransport: + return settings->AsyncTransport; + break; + case FreeRDP_ToggleFullscreen: return settings->ToggleFullscreen; break; @@ -1091,6 +1099,7 @@ BOOL freerdp_get_param_bool(rdpSettings* settings, int id) break; default: + fprintf(stderr, "freerdp_get_param_bool: unknown id: %d\n", id); return -1; break; } @@ -1298,6 +1307,10 @@ int freerdp_set_param_bool(rdpSettings* settings, int id, BOOL param) settings->DisableCredentialsDelegation = param; break; + case FreeRDP_AuthenticationLevel: + settings->AuthenticationLevel = param; + break; + case FreeRDP_MstscCookieMode: settings->MstscCookieMode = param; break; @@ -1350,6 +1363,10 @@ int freerdp_set_param_bool(rdpSettings* settings, int id, BOOL param) settings->AsyncChannels = param; break; + case FreeRDP_AsyncTransport: + settings->AsyncTransport = param; + break; + case FreeRDP_ToggleFullscreen: settings->ToggleFullscreen = param; break; @@ -1567,6 +1584,7 @@ int freerdp_set_param_bool(rdpSettings* settings, int id, BOOL param) break; default: + fprintf(stderr, "freerdp_set_param_bool: unknown id %d (param = %d)\n", id, param); return -1; break; } @@ -1590,6 +1608,7 @@ int freerdp_get_param_int(rdpSettings* settings, int id) break; default: + fprintf(stderr, "freerdp_get_param_int: unknown id: %d\n", id); return 0; break; } @@ -1610,6 +1629,7 @@ int freerdp_set_param_int(rdpSettings* settings, int id, int param) break; default: + fprintf(stderr, "freerdp_set_param_int: unknown id %d (param = %d)\n", id, param); return -1; break; } @@ -1920,6 +1940,7 @@ UINT32 freerdp_get_param_uint32(rdpSettings* settings, int id) break; default: + fprintf(stderr, "freerdp_get_param_uint32: unknown id: %d\n", id); return 0; break; } @@ -2228,6 +2249,7 @@ int freerdp_set_param_uint32(rdpSettings* settings, int id, UINT32 param) break; default: + fprintf(stderr, "freerdp_set_param_uint32: unknown id %d (param = %u)\n", id, param); return -1; break; } @@ -2247,6 +2269,7 @@ UINT64 freerdp_get_param_uint64(rdpSettings* settings, int id) break; default: + fprintf(stderr, "freerdp_get_param_uint64: unknown id: %d\n", id); return -1; break; } @@ -2263,6 +2286,7 @@ int freerdp_set_param_uint64(rdpSettings* settings, int id, UINT64 param) break; default: + fprintf(stderr, "freerdp_set_param_uint64: unknown id %d (param = %u)\n", id, (UINT32) param); return -1; break; } @@ -2438,6 +2462,7 @@ char* freerdp_get_param_string(rdpSettings* settings, int id) break; default: + fprintf(stderr, "freerdp_get_param_string: unknown id: %d\n", id); return NULL; break; } @@ -2650,6 +2675,7 @@ int freerdp_set_param_string(rdpSettings* settings, int id, const char* param) break; default: + fprintf(stderr, "freerdp_set_param_string: unknown id %d (param = %s)\n", id, param); return -1; break; } @@ -2669,6 +2695,7 @@ double freerdp_get_param_double(rdpSettings* settings, int id) break; default: + fprintf(stderr, "freerdp_get_param_double: unknown id: %d\n", id); return 0; break; } diff --git a/libfreerdp/core/client.c b/libfreerdp/core/client.c index c5dd47c29..cc13e28f8 100644 --- a/libfreerdp/core/client.c +++ b/libfreerdp/core/client.c @@ -55,7 +55,7 @@ CHANNEL_OPEN_DATA* freerdp_channels_find_channel_open_data_by_name(rdpChannels* /* returns rdpChannel for the channel name passed in */ rdpMcsChannel* freerdp_channels_find_channel_by_name(rdpRdp* rdp, const char* name) { - int index; + UINT32 index; rdpMcsChannel* channel; rdpMcs* mcs = rdp->mcs; @@ -221,7 +221,7 @@ int freerdp_channels_post_connect(rdpChannels* channels, freerdp* instance) int freerdp_channels_data(freerdp* instance, UINT16 channelId, BYTE* data, int dataSize, int flags, int totalSize) { - int index; + UINT32 index; rdpMcs* mcs; rdpChannels* channels; rdpMcsChannel* channel = NULL; diff --git a/libfreerdp/core/gateway/http.c b/libfreerdp/core/gateway/http.c index c9f33f01a..610b23091 100644 --- a/libfreerdp/core/gateway/http.c +++ b/libfreerdp/core/gateway/http.c @@ -26,6 +26,10 @@ #include #include +#ifdef HAVE_VALGRIND_MEMCHECK_H +#include +#endif + #include "http.h" HttpContext* http_context_new() @@ -472,7 +476,7 @@ HttpResponse* http_response_recv(rdpTls* tls) nbytes = 0; length = 10000; content = NULL; - buffer = malloc(length); + buffer = calloc(length, 1); if (!buffer) return NULL; @@ -487,14 +491,20 @@ HttpResponse* http_response_recv(rdpTls* tls) { while (nbytes < 5) { - status = tls_read(tls, p, length - nbytes); + status = BIO_read(tls->bio, p, length - nbytes); - if (status < 0) - goto out_error; + if (status <= 0) + { + if (!BIO_should_retry(tls->bio)) + goto out_error; - if (!status) + USleep(100); continue; + } +#ifdef HAVE_VALGRIND_MEMCHECK_H + VALGRIND_MAKE_MEM_DEFINED(p, status); +#endif nbytes += status; p = (BYTE*) &buffer[nbytes]; } @@ -503,7 +513,7 @@ HttpResponse* http_response_recv(rdpTls* tls) if (!header_end) { - fprintf(stderr, "http_response_recv: invalid response:\n"); + fprintf(stderr, "%s: invalid response:\n", __FUNCTION__); winpr_HexDump(buffer, status); goto out_error; } @@ -517,7 +527,7 @@ HttpResponse* http_response_recv(rdpTls* tls) header_end[0] = '\0'; header_end[1] = '\0'; - content = &header_end[2]; + content = header_end + 2; count = 0; line = (char*) buffer; @@ -552,11 +562,14 @@ HttpResponse* http_response_recv(rdpTls* tls) if (!http_response_parse_header(http_response)) goto out_error; - if (http_response->ContentLength > 0) + http_response->bodyLen = nbytes - (content - (char *)buffer); + if (http_response->bodyLen > 0) { - http_response->Content = _strdup(content); - if (!http_response->Content) + http_response->BodyContent = (BYTE *)malloc(http_response->bodyLen); + if (!http_response->BodyContent) goto out_error; + + CopyMemory(http_response->BodyContent, content, http_response->bodyLen); } break; @@ -627,7 +640,7 @@ void http_response_free(HttpResponse* http_response) ListDictionary_Free(http_response->Authenticates); if (http_response->ContentLength > 0) - free(http_response->Content); + free(http_response->BodyContent); free(http_response); } diff --git a/libfreerdp/core/gateway/http.h b/libfreerdp/core/gateway/http.h index 748b45a36..ded9ba214 100644 --- a/libfreerdp/core/gateway/http.h +++ b/libfreerdp/core/gateway/http.h @@ -84,7 +84,8 @@ struct _http_response wListDictionary *Authenticates; int ContentLength; - char* Content; + BYTE *BodyContent; + int bodyLen; }; void http_response_print(HttpResponse* http_response); diff --git a/libfreerdp/core/gateway/ncacn_http.c b/libfreerdp/core/gateway/ncacn_http.c index 270dafbcf..b5beff4b2 100644 --- a/libfreerdp/core/gateway/ncacn_http.c +++ b/libfreerdp/core/gateway/ncacn_http.c @@ -98,6 +98,8 @@ int rpc_ncacn_http_recv_in_channel_response(rdpRpc* rpc) rdpNtlm* ntlm = rpc->NtlmHttpIn->ntlm; http_response = http_response_recv(rpc->TlsIn); + if (!http_response) + return -1; if (ListDictionary_Contains(http_response->Authenticates, "NTLM")) { @@ -105,14 +107,12 @@ int rpc_ncacn_http_recv_in_channel_response(rdpRpc* rpc) if (!token64) goto out; - ntlm_token_data = NULL; crypto_base64_decode(token64, strlen(token64), &ntlm_token_data, &ntlm_token_length); } +out: ntlm->inputBuffer[0].pvBuffer = ntlm_token_data; ntlm->inputBuffer[0].cbBuffer = ntlm_token_length; - -out: http_response_free(http_response); return 0; @@ -123,25 +123,19 @@ int rpc_ncacn_http_ntlm_init(rdpRpc* rpc, TSG_CHANNEL channel) rdpNtlm* ntlm = NULL; rdpSettings* settings = rpc->settings; freerdp* instance = (freerdp*) rpc->settings->instance; - BOOL promptPassword = FALSE; if (channel == TSG_CHANNEL_IN) ntlm = rpc->NtlmHttpIn->ntlm; else if (channel == TSG_CHANNEL_OUT) ntlm = rpc->NtlmHttpOut->ntlm; - if ((!settings->GatewayPassword) || (!settings->GatewayUsername) - || (!strlen(settings->GatewayPassword)) || (!strlen(settings->GatewayUsername))) - { - promptPassword = TRUE; - } - - if (promptPassword) + if (!settings->GatewayPassword || !settings->GatewayUsername || + !strlen(settings->GatewayPassword) || !strlen(settings->GatewayUsername)) { if (instance->GatewayAuthenticate) { - BOOL proceed = instance->GatewayAuthenticate(instance, - &settings->GatewayUsername, &settings->GatewayPassword, &settings->GatewayDomain); + BOOL proceed = instance->GatewayAuthenticate(instance, &settings->GatewayUsername, + &settings->GatewayPassword, &settings->GatewayDomain); if (!proceed) { @@ -240,12 +234,10 @@ int rpc_ncacn_http_recv_out_channel_response(rdpRpc* rpc) char *token64 = ListDictionary_GetItemValue(http_response->Authenticates, "NTLM"); crypto_base64_decode(token64, strlen(token64), &ntlm_token_data, &ntlm_token_length); } - ntlm->inputBuffer[0].pvBuffer = ntlm_token_data; ntlm->inputBuffer[0].cbBuffer = ntlm_token_length; - + http_response_free(http_response); - return 0; } @@ -259,15 +251,12 @@ BOOL rpc_ntlm_http_out_connect(rdpRpc* rpc) success = TRUE; /* Send OUT Channel Request */ - rpc_ncacn_http_send_out_channel_request(rpc); /* Receive OUT Channel Response */ - rpc_ncacn_http_recv_out_channel_response(rpc); /* Send OUT Channel Request */ - rpc_ncacn_http_send_out_channel_request(rpc); ntlm_client_uninit(ntlm); @@ -296,13 +285,11 @@ void rpc_ntlm_http_init_channel(rdpRpc* rpc, rdpNtlmHttp* ntlm_http, TSG_CHANNEL if (channel == TSG_CHANNEL_IN) { - http_context_set_pragma(ntlm_http->context, - "ResourceTypeUuid=44e265dd-7daf-42cd-8560-3cdb6e7a2729"); + http_context_set_pragma(ntlm_http->context, "ResourceTypeUuid=44e265dd-7daf-42cd-8560-3cdb6e7a2729"); } else if (channel == TSG_CHANNEL_OUT) { - http_context_set_pragma(ntlm_http->context, - "ResourceTypeUuid=44e265dd-7daf-42cd-8560-3cdb6e7a2729" ", " + http_context_set_pragma(ntlm_http->context, "ResourceTypeUuid=44e265dd-7daf-42cd-8560-3cdb6e7a2729, " "SessionId=fbd9c34f-397d-471d-a109-1b08cc554624"); } } diff --git a/libfreerdp/core/gateway/rpc.c b/libfreerdp/core/gateway/rpc.c index c91a71071..2432ab06c 100644 --- a/libfreerdp/core/gateway/rpc.c +++ b/libfreerdp/core/gateway/rpc.c @@ -33,6 +33,11 @@ #include #include +#include + +#ifdef HAVE_VALGRIND_MEMCHECK_H +#include +#endif #include "http.h" #include "ntlm.h" @@ -235,80 +240,77 @@ BOOL rpc_get_stub_data_info(rdpRpc* rpc, BYTE* buffer, UINT32* offset, UINT32* l { UINT32 alloc_hint = 0; rpcconn_hdr_t* header; + UINT32 frag_length; + UINT32 auth_length; + UINT32 auth_pad_length; + UINT32 sec_trailer_offset; + rpc_sec_trailer* sec_trailer; *offset = RPC_COMMON_FIELDS_LENGTH; header = ((rpcconn_hdr_t*) buffer); - if (header->common.ptype == PTYPE_RESPONSE) + switch (header->common.ptype) { - *offset += 8; - rpc_offset_align(offset, 8); - alloc_hint = header->response.alloc_hint; - } - else if (header->common.ptype == PTYPE_REQUEST) - { - *offset += 4; - rpc_offset_align(offset, 8); - alloc_hint = header->request.alloc_hint; - } - else if (header->common.ptype == PTYPE_RTS) - { - *offset += 4; - } - else - { - return FALSE; + case PTYPE_RESPONSE: + *offset += 8; + rpc_offset_align(offset, 8); + alloc_hint = header->response.alloc_hint; + break; + case PTYPE_REQUEST: + *offset += 4; + rpc_offset_align(offset, 8); + alloc_hint = header->request.alloc_hint; + break; + case PTYPE_RTS: + *offset += 4; + break; + default: + fprintf(stderr, "%s: unknown ptype=0x%x\n", __FUNCTION__, header->common.ptype); + return FALSE; } - if (length) + if (!length) + return TRUE; + + if (header->common.ptype == PTYPE_REQUEST) { - if (header->common.ptype == PTYPE_REQUEST) - { - UINT32 sec_trailer_offset; + UINT32 sec_trailer_offset; - sec_trailer_offset = header->common.frag_length - header->common.auth_length - 8; - *length = sec_trailer_offset - *offset; - } - else - { - UINT32 frag_length; - UINT32 auth_length; - UINT32 auth_pad_length; - UINT32 sec_trailer_offset; - rpc_sec_trailer* sec_trailer; + sec_trailer_offset = header->common.frag_length - header->common.auth_length - 8; + *length = sec_trailer_offset - *offset; + return TRUE; + } - frag_length = header->common.frag_length; - auth_length = header->common.auth_length; - sec_trailer_offset = frag_length - auth_length - 8; - sec_trailer = (rpc_sec_trailer*) &buffer[sec_trailer_offset]; - auth_pad_length = sec_trailer->auth_pad_length; + frag_length = header->common.frag_length; + auth_length = header->common.auth_length; + + sec_trailer_offset = frag_length - auth_length - 8; + sec_trailer = (rpc_sec_trailer*) &buffer[sec_trailer_offset]; + auth_pad_length = sec_trailer->auth_pad_length; #if 0 - fprintf(stderr, "sec_trailer: type: %d level: %d pad_length: %d reserved: %d context_id: %d\n", - sec_trailer->auth_type, - sec_trailer->auth_level, - sec_trailer->auth_pad_length, - sec_trailer->auth_reserved, - sec_trailer->auth_context_id); + fprintf(stderr, "sec_trailer: type: %d level: %d pad_length: %d reserved: %d context_id: %d\n", + sec_trailer->auth_type, + sec_trailer->auth_level, + sec_trailer->auth_pad_length, + sec_trailer->auth_reserved, + sec_trailer->auth_context_id); #endif - /** - * According to [MS-RPCE], auth_pad_length is the number of padding - * octets used to 4-byte align the security trailer, but in practice - * we get values up to 15, which indicates 16-byte alignment. - */ + /** + * According to [MS-RPCE], auth_pad_length is the number of padding + * octets used to 4-byte align the security trailer, but in practice + * we get values up to 15, which indicates 16-byte alignment. + */ - if ((frag_length - (sec_trailer_offset + 8)) != auth_length) - { - fprintf(stderr, "invalid auth_length: actual: %d, expected: %d\n", auth_length, - (frag_length - (sec_trailer_offset + 8))); - } - - *length = frag_length - auth_length - 24 - 8 - auth_pad_length; - } + if ((frag_length - (sec_trailer_offset + 8)) != auth_length) + { + fprintf(stderr, "invalid auth_length: actual: %d, expected: %d\n", auth_length, + (frag_length - (sec_trailer_offset + 8))); } + *length = frag_length - auth_length - 24 - 8 - auth_pad_length; return TRUE; } @@ -316,12 +318,23 @@ int rpc_out_read(rdpRpc* rpc, BYTE* data, int length) { int status; - status = tls_read(rpc->TlsOut, data, length); + status = BIO_read(rpc->TlsOut->bio, data, length); + /* fprintf(stderr, "%s: length=%d => status=%d shouldRetry=%d\n", __FUNCTION__, length, + * status, BIO_should_retry(rpc->TlsOut->bio)); */ + if (status > 0) { +#ifdef HAVE_VALGRIND_MEMCHECK_H + VALGRIND_MAKE_MEM_DEFINED(data, status); +#endif + return status; + } - return status; + if (BIO_should_retry(rpc->TlsOut->bio)) + return 0; + + return -1; } -int rpc_out_write(rdpRpc* rpc, BYTE* data, int length) +int rpc_out_write(rdpRpc* rpc, const BYTE* data, int length) { int status; @@ -330,7 +343,7 @@ int rpc_out_write(rdpRpc* rpc, BYTE* data, int length) return status; } -int rpc_in_write(rdpRpc* rpc, BYTE* data, int length) +int rpc_in_write(rdpRpc* rpc, const BYTE* data, int length) { int status; @@ -360,20 +373,21 @@ int rpc_write(rdpRpc* rpc, BYTE* data, int length, UINT16 opnum) ntlm = rpc->ntlm; - if ((!ntlm) || (!ntlm->table)) + if (!ntlm || !ntlm->table) { - fprintf(stderr, "rpc_write: invalid ntlm context\n"); + fprintf(stderr, "%s: invalid ntlm context\n", __FUNCTION__); return -1; } if (ntlm->table->QueryContextAttributes(&ntlm->context, SECPKG_ATTR_SIZES, &ntlm->ContextSizes) != SEC_E_OK) { - fprintf(stderr, "QueryContextAttributes SECPKG_ATTR_SIZES failure\n"); + fprintf(stderr, "%s: QueryContextAttributes SECPKG_ATTR_SIZES failure\n", __FUNCTION__); return -1; } - request_pdu = (rpcconn_request_hdr_t*) malloc(sizeof(rpcconn_request_hdr_t)); - ZeroMemory(request_pdu, sizeof(rpcconn_request_hdr_t)); + request_pdu = (rpcconn_request_hdr_t*) calloc(1, sizeof(rpcconn_request_hdr_t)); + if (!request_pdu) + return -1; rpc_pdu_header_init(rpc, (rpcconn_hdr_t*) request_pdu); @@ -386,7 +400,11 @@ int rpc_write(rdpRpc* rpc, BYTE* data, int length, UINT16 opnum) request_pdu->opnum = opnum; clientCall = rpc_client_call_new(request_pdu->call_id, request_pdu->opnum); - ArrayList_Add(rpc->client->ClientCallList, clientCall); + if (!clientCall) + goto out_free_pdu; + + if (ArrayList_Add(rpc->client->ClientCallList, clientCall) < 0) + goto out_free_clientCall; if (request_pdu->opnum == TsProxySetupReceivePipeOpnum) rpc->PipeCallId = request_pdu->call_id; @@ -407,8 +425,9 @@ int rpc_write(rdpRpc* rpc, BYTE* data, int length, UINT16 opnum) request_pdu->frag_length = offset; - buffer = (BYTE*) malloc(request_pdu->frag_length); - + buffer = (BYTE*) calloc(1, request_pdu->frag_length); + if (!buffer) + goto out_free_pdu; CopyMemory(buffer, request_pdu, 24); offset = 24; @@ -427,15 +446,15 @@ int rpc_write(rdpRpc* rpc, BYTE* data, int length, UINT16 opnum) Buffers[0].cbBuffer = offset; Buffers[1].cbBuffer = ntlm->ContextSizes.cbMaxSignature; - Buffers[1].pvBuffer = malloc(Buffers[1].cbBuffer); - ZeroMemory(Buffers[1].pvBuffer, Buffers[1].cbBuffer); + Buffers[1].pvBuffer = calloc(1, Buffers[1].cbBuffer); + if (!Buffers[1].pvBuffer) + return -1; Message.cBuffers = 2; Message.ulVersion = SECBUFFER_VERSION; Message.pBuffers = (PSecBuffer) &Buffers; encrypt_status = ntlm->table->EncryptMessage(&ntlm->context, 0, &Message, rpc->SendSeqNum++); - if (encrypt_status != SEC_E_OK) { fprintf(stderr, "EncryptMessage status: 0x%08X\n", encrypt_status); @@ -447,12 +466,18 @@ int rpc_write(rdpRpc* rpc, BYTE* data, int length, UINT16 opnum) offset += Buffers[1].cbBuffer; free(Buffers[1].pvBuffer); - if (rpc_send_enqueue_pdu(rpc, buffer, request_pdu->frag_length) != 0) + if (rpc_send_enqueue_pdu(rpc, buffer, request_pdu->frag_length) < 0) length = -1; free(request_pdu); return length; + +out_free_clientCall: + rpc_client_call_free(clientCall); +out_free_pdu: + free(request_pdu); + return -1; } BOOL rpc_connect(rdpRpc* rpc) @@ -592,13 +617,17 @@ rdpRpc* rpc_new(rdpTransport* transport) rpc->CallId = 2; - rpc_client_new(rpc); + if (rpc_client_new(rpc) < 0) + goto out_free_virtualConnectionCookieTable; rpc->client->SynchronousSend = TRUE; rpc->client->SynchronousReceive = TRUE; return rpc; +out_free_virtualConnectionCookieTable: + rpc_client_free(rpc); + ArrayList_Free(rpc->VirtualConnectionCookieTable); out_free_virtual_connection: rpc_client_virtual_connection_free(rpc->VirtualConnection); out_free_ntlm_http_out: diff --git a/libfreerdp/core/gateway/rpc.h b/libfreerdp/core/gateway/rpc.h index d10d665c7..c86a8618f 100644 --- a/libfreerdp/core/gateway/rpc.h +++ b/libfreerdp/core/gateway/rpc.h @@ -772,8 +772,8 @@ UINT32 rpc_offset_pad(UINT32* offset, UINT32 pad); int rpc_out_read(rdpRpc* rpc, BYTE* data, int length); -int rpc_out_write(rdpRpc* rpc, BYTE* data, int length); -int rpc_in_write(rdpRpc* rpc, BYTE* data, int length); +int rpc_out_write(rdpRpc* rpc, const BYTE* data, int length); +int rpc_in_write(rdpRpc* rpc, const BYTE* data, int length); BOOL rpc_get_stub_data_info(rdpRpc* rpc, BYTE* header, UINT32* offset, UINT32* length); diff --git a/libfreerdp/core/gateway/rpc_bind.c b/libfreerdp/core/gateway/rpc_bind.c index cf02a802a..ceae95159 100644 --- a/libfreerdp/core/gateway/rpc_bind.c +++ b/libfreerdp/core/gateway/rpc_bind.c @@ -103,6 +103,8 @@ int rpc_send_bind_pdu(rdpRpc* rpc) DEBUG_RPC("Sending bind PDU"); rpc->ntlm = ntlm_new(); + if (!rpc->ntlm) + return -1; if ((!settings->GatewayPassword) || (!settings->GatewayUsername) || (!strlen(settings->GatewayPassword)) || (!strlen(settings->GatewayUsername))) @@ -129,17 +131,22 @@ int rpc_send_bind_pdu(rdpRpc* rpc) settings->Username = _strdup(settings->GatewayUsername); settings->Domain = _strdup(settings->GatewayDomain); settings->Password = _strdup(settings->GatewayPassword); + + if (!settings->Username || !settings->Domain || settings->Password) + return -1; } } } - ntlm_client_init(rpc->ntlm, FALSE, settings->GatewayUsername, settings->GatewayDomain, settings->GatewayPassword, NULL); - ntlm_client_make_spn(rpc->ntlm, NULL, settings->GatewayHostname); + if (!ntlm_client_init(rpc->ntlm, FALSE, settings->GatewayUsername, settings->GatewayDomain, settings->GatewayPassword, NULL) || + !ntlm_client_make_spn(rpc->ntlm, NULL, settings->GatewayHostname) || + !ntlm_authenticate(rpc->ntlm) + ) + return -1; - ntlm_authenticate(rpc->ntlm); - - bind_pdu = (rpcconn_bind_hdr_t*) malloc(sizeof(rpcconn_bind_hdr_t)); - ZeroMemory(bind_pdu, sizeof(rpcconn_bind_hdr_t)); + bind_pdu = (rpcconn_bind_hdr_t*) calloc(1, sizeof(rpcconn_bind_hdr_t)); + if (!bind_pdu) + return -1; rpc_pdu_header_init(rpc, (rpcconn_hdr_t*) bind_pdu); @@ -159,6 +166,8 @@ int rpc_send_bind_pdu(rdpRpc* rpc) bind_pdu->p_context_elem.reserved2 = 0; bind_pdu->p_context_elem.p_cont_elem = malloc(sizeof(p_cont_elem_t) * bind_pdu->p_context_elem.n_context_elem); + if (!bind_pdu->p_context_elem.p_cont_elem) + return -1; p_cont_elem = &bind_pdu->p_context_elem.p_cont_elem[0]; @@ -196,6 +205,8 @@ int rpc_send_bind_pdu(rdpRpc* rpc) bind_pdu->frag_length = offset; buffer = (BYTE*) malloc(bind_pdu->frag_length); + if (!buffer) + return -1; CopyMemory(buffer, bind_pdu, 24); CopyMemory(&buffer[24], &bind_pdu->p_context_elem, 4); @@ -214,7 +225,10 @@ int rpc_send_bind_pdu(rdpRpc* rpc) length = bind_pdu->frag_length; clientCall = rpc_client_call_new(bind_pdu->call_id, 0); - ArrayList_Add(rpc->client->ClientCallList, clientCall); + if (!clientCall) + return -1; + if (ArrayList_Add(rpc->client->ClientCallList, clientCall) < 0) + return -1; if (rpc_send_enqueue_pdu(rpc, buffer, length) != 0) length = -1; diff --git a/libfreerdp/core/gateway/rpc_client.c b/libfreerdp/core/gateway/rpc_client.c index dff88b3e5..c3613f6be 100644 --- a/libfreerdp/core/gateway/rpc_client.c +++ b/libfreerdp/core/gateway/rpc_client.c @@ -34,9 +34,7 @@ #include #include "rpc_fault.h" - #include "rpc_client.h" - #include "../rdp.h" #define SYNCHRONOUS_TIMEOUT 5000 @@ -69,8 +67,15 @@ RPC_PDU* rpc_client_receive_pool_take(rdpRpc* rpc) if (!pdu) { - pdu = (RPC_PDU*) malloc(sizeof(RPC_PDU)); + pdu = (RPC_PDU *)malloc(sizeof(RPC_PDU)); + if (!pdu) + return NULL; pdu->s = Stream_New(NULL, rpc->max_recv_frag); + if (!pdu->s) + { + free(pdu); + return NULL; + } } pdu->CallId = 0; @@ -84,8 +89,7 @@ RPC_PDU* rpc_client_receive_pool_take(rdpRpc* rpc) int rpc_client_receive_pool_return(rdpRpc* rpc, RPC_PDU* pdu) { - Queue_Enqueue(rpc->client->ReceivePool, pdu); - return 0; + return Queue_Enqueue(rpc->client->ReceivePool, pdu) == TRUE ? 0 : -1; } int rpc_client_on_fragment_received_event(rdpRpc* rpc) @@ -97,7 +101,7 @@ int rpc_client_on_fragment_received_event(rdpRpc* rpc) rpcconn_hdr_t* header; freerdp* instance; - instance = (freerdp*) rpc->transport->settings->instance; + instance = (freerdp *)rpc->transport->settings->instance; if (!rpc->client->pdu) rpc->client->pdu = rpc_client_receive_pool_take(rpc); @@ -125,34 +129,29 @@ int rpc_client_on_fragment_received_event(rdpRpc* rpc) return 0; } - if (header->common.ptype == PTYPE_RTS) + switch (header->common.ptype) { - if (rpc->VirtualConnection->State >= VIRTUAL_CONNECTION_STATE_OPENED) - { - //fprintf(stderr, "Receiving Out-of-Sequence RTS PDU\n"); + case PTYPE_RTS: + if (rpc->VirtualConnection->State < VIRTUAL_CONNECTION_STATE_OPENED) + { + fprintf(stderr, "%s: warning: unhandled RTS PDU\n", __FUNCTION__); + return 0; + } + fprintf(stderr, "%s: Receiving Out-of-Sequence RTS PDU\n", __FUNCTION__); rts_recv_out_of_sequence_pdu(rpc, buffer, header->common.frag_length); - rpc_client_fragment_pool_return(rpc, fragment); - } - else - { - fprintf(stderr, "warning: unhandled RTS PDU\n"); - } + return 0; - return 0; - } - else if (header->common.ptype == PTYPE_FAULT) - { - rpc_recv_fault_pdu(header); - Queue_Enqueue(rpc->client->ReceiveQueue, NULL); - return -1; - } - - if (header->common.ptype != PTYPE_RESPONSE) - { - fprintf(stderr, "Unexpected RPC PDU type: %d\n", header->common.ptype); - Queue_Enqueue(rpc->client->ReceiveQueue, NULL); - return -1; + case PTYPE_FAULT: + rpc_recv_fault_pdu(header); + Queue_Enqueue(rpc->client->ReceiveQueue, NULL); + return -1; + case PTYPE_RESPONSE: + break; + default: + fprintf(stderr, "%s: unexpected RPC PDU type %d\n", __FUNCTION__, header->common.ptype); + Queue_Enqueue(rpc->client->ReceiveQueue, NULL); + return -1; } rpc->VirtualConnection->DefaultOutChannel->BytesReceived += header->common.frag_length; @@ -160,7 +159,7 @@ int rpc_client_on_fragment_received_event(rdpRpc* rpc) if (!rpc_get_stub_data_info(rpc, buffer, &StubOffset, &StubLength)) { - fprintf(stderr, "rpc_recv_pdu_fragment: expected stub\n"); + fprintf(stderr, "%s: expected stub\n", __FUNCTION__); Queue_Enqueue(rpc->client->ReceiveQueue, NULL); return -1; } @@ -196,7 +195,7 @@ int rpc_client_on_fragment_received_event(rdpRpc* rpc) if (rpc->StubCallId != header->common.call_id) { - fprintf(stderr, "invalid call_id: actual: %d, expected: %d, frag_count: %d\n", + fprintf(stderr, "%s: invalid call_id: actual: %d, expected: %d, frag_count: %d\n", __FUNCTION__, rpc->StubCallId, header->common.call_id, rpc->StubFragCount); } @@ -243,27 +242,34 @@ int rpc_client_on_read_event(rdpRpc* rpc) int status = -1; rpcconn_common_hdr_t* header; - if (!rpc->client->RecvFrag) - rpc->client->RecvFrag = rpc_client_fragment_pool_take(rpc); - - position = Stream_GetPosition(rpc->client->RecvFrag); - - if (Stream_GetPosition(rpc->client->RecvFrag) < RPC_COMMON_FIELDS_LENGTH) + while (1) { - status = rpc_out_read(rpc, Stream_Pointer(rpc->client->RecvFrag), - RPC_COMMON_FIELDS_LENGTH - Stream_GetPosition(rpc->client->RecvFrag)); + if (!rpc->client->RecvFrag) + rpc->client->RecvFrag = rpc_client_fragment_pool_take(rpc); - if (status < 0) + position = Stream_GetPosition(rpc->client->RecvFrag); + + while (Stream_GetPosition(rpc->client->RecvFrag) < RPC_COMMON_FIELDS_LENGTH) { - fprintf(stderr, "rpc_client_frag_read: error reading header\n"); - return -1; + status = rpc_out_read(rpc, Stream_Pointer(rpc->client->RecvFrag), + RPC_COMMON_FIELDS_LENGTH - Stream_GetPosition(rpc->client->RecvFrag)); + + if (status < 0) + { + fprintf(stderr, "rpc_client_frag_read: error reading header\n"); + return -1; + } + + if (!status) + return 0; + + Stream_Seek(rpc->client->RecvFrag, status); } - Stream_Seek(rpc->client->RecvFrag, status); - } + if (Stream_GetPosition(rpc->client->RecvFrag) < RPC_COMMON_FIELDS_LENGTH) + return status; + - if (Stream_GetPosition(rpc->client->RecvFrag) >= RPC_COMMON_FIELDS_LENGTH) - { header = (rpcconn_common_hdr_t*) Stream_Buffer(rpc->client->RecvFrag); if (header->frag_length > rpc->max_recv_frag) @@ -274,45 +280,44 @@ int rpc_client_on_read_event(rdpRpc* rpc) return -1; } - if (Stream_GetPosition(rpc->client->RecvFrag) < header->frag_length) + while (Stream_GetPosition(rpc->client->RecvFrag) < header->frag_length) { status = rpc_out_read(rpc, Stream_Pointer(rpc->client->RecvFrag), header->frag_length - Stream_GetPosition(rpc->client->RecvFrag)); if (status < 0) { - fprintf(stderr, "rpc_client_frag_read: error reading fragment body\n"); + fprintf(stderr, "%s: error reading fragment body\n", __FUNCTION__); return -1; } + if (!status) + return 0; + Stream_Seek(rpc->client->RecvFrag, status); } - } - else - { - return status; - } - if (status < 0) - return -1; - - status = Stream_GetPosition(rpc->client->RecvFrag) - position; - - if (Stream_GetPosition(rpc->client->RecvFrag) >= header->frag_length) - { - /* complete fragment received */ - - Stream_Length(rpc->client->RecvFrag) = Stream_GetPosition(rpc->client->RecvFrag); - Stream_SetPosition(rpc->client->RecvFrag, 0); - - Queue_Enqueue(rpc->client->FragmentQueue, rpc->client->RecvFrag); - rpc->client->RecvFrag = NULL; - - if (rpc_client_on_fragment_received_event(rpc) < 0) + if (status < 0) return -1; + + status = Stream_GetPosition(rpc->client->RecvFrag) - position; + + if (Stream_GetPosition(rpc->client->RecvFrag) >= header->frag_length) + { + /* complete fragment received */ + + Stream_Length(rpc->client->RecvFrag) = Stream_GetPosition(rpc->client->RecvFrag); + Stream_SetPosition(rpc->client->RecvFrag, 0); + + Queue_Enqueue(rpc->client->FragmentQueue, rpc->client->RecvFrag); + rpc->client->RecvFrag = NULL; + + if (rpc_client_on_fragment_received_event(rpc) < 0) + return -1; + } } - return status; + return 0; } /** @@ -349,13 +354,12 @@ RpcClientCall* rpc_client_call_new(UINT32 CallId, UINT32 OpNum) RpcClientCall* clientCall; clientCall = (RpcClientCall*) malloc(sizeof(RpcClientCall)); + if (!clientCall) + return NULL; - if (clientCall) - { - clientCall->CallId = CallId; - clientCall->OpNum = OpNum; - clientCall->State = RPC_CLIENT_CALL_STATE_SEND_PDUS; - } + clientCall->CallId = CallId; + clientCall->OpNum = OpNum; + clientCall->State = RPC_CLIENT_CALL_STATE_SEND_PDUS; return clientCall; } @@ -371,16 +375,22 @@ int rpc_send_enqueue_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length) int status; pdu = (RPC_PDU*) malloc(sizeof(RPC_PDU)); - pdu->s = Stream_New(buffer, length); + if (!pdu) + return -1; - Queue_Enqueue(rpc->client->SendQueue, pdu); + pdu->s = Stream_New(buffer, length); + if (!pdu->s) + goto out_free; + + if (!Queue_Enqueue(rpc->client->SendQueue, pdu)) + goto out_free_stream; if (rpc->client->SynchronousSend) { status = WaitForSingleObject(rpc->client->PduSentEvent, SYNCHRONOUS_TIMEOUT); if (status == WAIT_TIMEOUT) { - fprintf(stderr, "rpc_send_enqueue_pdu: timed out waiting for pdu sent event\n"); + fprintf(stderr, "%s: timed out waiting for pdu sent event %p\n", __FUNCTION__, rpc->client->PduSentEvent); return -1; } @@ -388,6 +398,12 @@ int rpc_send_enqueue_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length) } return 0; + +out_free_stream: + Stream_Free(pdu->s, TRUE); +out_free: + free(pdu); + return -1; } int rpc_send_dequeue_pdu(rdpRpc* rpc) @@ -396,13 +412,14 @@ int rpc_send_dequeue_pdu(rdpRpc* rpc) RPC_PDU* pdu; RpcClientCall* clientCall; rpcconn_common_hdr_t* header; + RpcInChannel *inChannel; pdu = (RPC_PDU*) Queue_Dequeue(rpc->client->SendQueue); - if (!pdu) return 0; - WaitForSingleObject(rpc->VirtualConnection->DefaultInChannel->Mutex, INFINITE); + inChannel = rpc->VirtualConnection->DefaultInChannel; + WaitForSingleObject(inChannel->Mutex, INFINITE); status = rpc_in_write(rpc, Stream_Buffer(pdu->s), Stream_Length(pdu->s)); @@ -410,7 +427,7 @@ int rpc_send_dequeue_pdu(rdpRpc* rpc) clientCall = rpc_client_call_find_by_id(rpc, header->call_id); clientCall->State = RPC_CLIENT_CALL_STATE_DISPATCHED; - ReleaseMutex(rpc->VirtualConnection->DefaultInChannel->Mutex); + ReleaseMutex(inChannel->Mutex); /* * This protocol specifies that only RPC PDUs are subject to the flow control abstract @@ -421,8 +438,8 @@ int rpc_send_dequeue_pdu(rdpRpc* rpc) if (header->ptype == PTYPE_REQUEST) { - rpc->VirtualConnection->DefaultInChannel->BytesSent += status; - rpc->VirtualConnection->DefaultInChannel->SenderAvailableWindow -= status; + inChannel->BytesSent += status; + inChannel->SenderAvailableWindow -= status; } Stream_Free(pdu->s, TRUE); @@ -440,57 +457,48 @@ RPC_PDU* rpc_recv_dequeue_pdu(rdpRpc* rpc) DWORD dwMilliseconds; DWORD result; - pdu = NULL; - dwMilliseconds = rpc->client->SynchronousReceive ? SYNCHRONOUS_TIMEOUT : 0; + dwMilliseconds = rpc->client->SynchronousReceive ? SYNCHRONOUS_TIMEOUT * 4 : 0; result = WaitForSingleObject(Queue_Event(rpc->client->ReceiveQueue), dwMilliseconds); if (result == WAIT_TIMEOUT) { - fprintf(stderr, "rpc_recv_dequeue_pdu: timed out waiting for receive event\n"); + fprintf(stderr, "%s: timed out waiting for receive event\n", __FUNCTION__); return NULL; } - if (result == WAIT_OBJECT_0) - { - pdu = (RPC_PDU*) Queue_Dequeue(rpc->client->ReceiveQueue); + if (result != WAIT_OBJECT_0) + return NULL; + + pdu = (RPC_PDU *)Queue_Dequeue(rpc->client->ReceiveQueue); #ifdef WITH_DEBUG_TSG - if (pdu) - { - fprintf(stderr, "Receiving PDU (length: %d, CallId: %d)\n", pdu->s->length, pdu->CallId); - winpr_HexDump(Stream_Buffer(pdu->s), Stream_Length(pdu->s)); - fprintf(stderr, "\n"); - } -#endif - - return pdu; + if (pdu) + { + fprintf(stderr, "Receiving PDU (length: %d, CallId: %d)\n", pdu->s->length, pdu->CallId); + winpr_HexDump(Stream_Buffer(pdu->s), Stream_Length(pdu->s)); + fprintf(stderr, "\n"); } + else + { + fprintf(stderr, "Receiving a NULL PDU\n"); + } +#endif return pdu; } RPC_PDU* rpc_recv_peek_pdu(rdpRpc* rpc) { - RPC_PDU* pdu; DWORD dwMilliseconds; DWORD result; - pdu = NULL; dwMilliseconds = rpc->client->SynchronousReceive ? SYNCHRONOUS_TIMEOUT : 0; result = WaitForSingleObject(Queue_Event(rpc->client->ReceiveQueue), dwMilliseconds); - if (result == WAIT_TIMEOUT) - { + if (result != WAIT_OBJECT_0) return NULL; - } - if (result == WAIT_OBJECT_0) - { - pdu = (RPC_PDU*) Queue_Peek(rpc->client->ReceiveQueue); - return pdu; - } - - return pdu; + return (RPC_PDU *)Queue_Peek(rpc->client->ReceiveQueue); } static void* rpc_client_thread(void* arg) @@ -500,40 +508,52 @@ static void* rpc_client_thread(void* arg) DWORD nCount; HANDLE events[3]; HANDLE ReadEvent; + int fd; rpc = (rdpRpc*) arg; + fd = BIO_get_fd(rpc->TlsOut->bio, NULL); - ReadEvent = CreateFileDescriptorEvent(NULL, TRUE, FALSE, rpc->TlsOut->sockfd); + ReadEvent = CreateFileDescriptorEvent(NULL, TRUE, FALSE, fd); nCount = 0; events[nCount++] = rpc->client->StopEvent; events[nCount++] = Queue_Event(rpc->client->SendQueue); events[nCount++] = ReadEvent; + /* Do a first free run in case some bytes were set from the HTTP headers. + * We also have to do it because most of the time the underlying socket has notified, + * and the ssl layer has eaten all bytes, so we won't be notified any more even if the + * bytes are buffered locally + */ + if (rpc_client_on_read_event(rpc) < 0) + { + fprintf(stderr, "%s: an error occured when treating first packet\n", __FUNCTION__); + goto out; + } + while (rpc->transport->layer != TRANSPORT_LAYER_CLOSED) { status = WaitForMultipleObjects(nCount, events, FALSE, 100); - if (status != WAIT_TIMEOUT) + if (status == WAIT_TIMEOUT) + continue; + + if (WaitForSingleObject(rpc->client->StopEvent, 0) == WAIT_OBJECT_0) + break; + + if (WaitForSingleObject(ReadEvent, 0) == WAIT_OBJECT_0) { - if (WaitForSingleObject(rpc->client->StopEvent, 0) == WAIT_OBJECT_0) - { + if (rpc_client_on_read_event(rpc) < 0) break; - } + } - if (WaitForSingleObject(ReadEvent, 0) == WAIT_OBJECT_0) - { - if (rpc_client_on_read_event(rpc) < 0) - break; - } - - if (WaitForSingleObject(Queue_Event(rpc->client->SendQueue), 0) == WAIT_OBJECT_0) - { - rpc_send_dequeue_pdu(rpc); - } + if (WaitForSingleObject(Queue_Event(rpc->client->SendQueue), 0) == WAIT_OBJECT_0) + { + rpc_send_dequeue_pdu(rpc); } } +out: CloseHandle(ReadEvent); return NULL; @@ -541,6 +561,9 @@ static void* rpc_client_thread(void* arg) static void rpc_pdu_free(RPC_PDU* pdu) { + if (!pdu) + return; + Stream_Free(pdu->s, TRUE); free(pdu); } @@ -554,35 +577,55 @@ int rpc_client_new(rdpRpc* rpc) { RpcClient* client = NULL; - client = (RpcClient*) calloc(1, sizeof(RpcClient)); - - if (client) - { - client->StopEvent = CreateEvent(NULL, TRUE, FALSE, NULL); - client->PduSentEvent = CreateEvent(NULL, TRUE, FALSE, NULL); - - client->SendQueue = Queue_New(TRUE, -1, -1); - Queue_Object(client->SendQueue)->fnObjectFree = (OBJECT_FREE_FN) rpc_pdu_free; - - client->pdu = NULL; - client->ReceivePool = Queue_New(TRUE, -1, -1); - client->ReceiveQueue = Queue_New(TRUE, -1, -1); - Queue_Object(client->ReceivePool)->fnObjectFree = (OBJECT_FREE_FN) rpc_pdu_free; - Queue_Object(client->ReceiveQueue)->fnObjectFree = (OBJECT_FREE_FN) rpc_pdu_free; - - client->RecvFrag = NULL; - client->FragmentPool = Queue_New(TRUE, -1, -1); - client->FragmentQueue = Queue_New(TRUE, -1, -1); - - Queue_Object(client->FragmentPool)->fnObjectFree = (OBJECT_FREE_FN) rpc_fragment_free; - Queue_Object(client->FragmentQueue)->fnObjectFree = (OBJECT_FREE_FN) rpc_fragment_free; - - client->ClientCallList = ArrayList_New(TRUE); - ArrayList_Object(client->ClientCallList)->fnObjectFree = (OBJECT_FREE_FN) rpc_client_call_free; - } - + client = (RpcClient *)calloc(1, sizeof(RpcClient)); rpc->client = client; + if (!client) + return -1; + client->Thread = CreateThread(NULL, 0, + (LPTHREAD_START_ROUTINE) rpc_client_thread, + rpc, CREATE_SUSPENDED, NULL); + if (!client->Thread) + return -1; + + client->StopEvent = CreateEvent(NULL, TRUE, FALSE, NULL); + if (!client->StopEvent) + return -1; + client->PduSentEvent = CreateEvent(NULL, TRUE, FALSE, NULL); + if (!client->PduSentEvent) + return -1; + + client->SendQueue = Queue_New(TRUE, -1, -1); + if (!client->SendQueue) + return -1; + Queue_Object(client->SendQueue)->fnObjectFree = (OBJECT_FREE_FN) rpc_pdu_free; + + client->pdu = NULL; + client->ReceivePool = Queue_New(TRUE, -1, -1); + if (!client->ReceivePool) + return -1; + Queue_Object(client->ReceivePool)->fnObjectFree = (OBJECT_FREE_FN) rpc_pdu_free; + + client->ReceiveQueue = Queue_New(TRUE, -1, -1); + if (!client->ReceiveQueue) + return -1; + Queue_Object(client->ReceiveQueue)->fnObjectFree = (OBJECT_FREE_FN) rpc_pdu_free; + + client->RecvFrag = NULL; + client->FragmentPool = Queue_New(TRUE, -1, -1); + if (!client->FragmentPool) + return -1; + Queue_Object(client->FragmentPool)->fnObjectFree = (OBJECT_FREE_FN) rpc_fragment_free; + + client->FragmentQueue = Queue_New(TRUE, -1, -1); + if (!client->FragmentQueue) + return -1; + Queue_Object(client->FragmentQueue)->fnObjectFree = (OBJECT_FREE_FN) rpc_fragment_free; + + client->ClientCallList = ArrayList_New(TRUE); + if (!client->ClientCallList) + return -1; + ArrayList_Object(client->ClientCallList)->fnObjectFree = (OBJECT_FREE_FN) rpc_client_call_free; return 0; } @@ -604,9 +647,7 @@ int rpc_client_stop(rdpRpc* rpc) rpc->client->Thread = NULL; } - rpc_client_free(rpc); - - return 0; + return rpc_client_free(rpc); } int rpc_client_free(rdpRpc* rpc) @@ -615,31 +656,39 @@ int rpc_client_free(rdpRpc* rpc) client = rpc->client; - if (client) - { + if (!client) + return 0; + + if (client->SendQueue) Queue_Free(client->SendQueue); - if (client->RecvFrag) - rpc_fragment_free(client->RecvFrag); + if (client->RecvFrag) + rpc_fragment_free(client->RecvFrag); + if (client->FragmentPool) Queue_Free(client->FragmentPool); + if (client->FragmentQueue) Queue_Free(client->FragmentQueue); - if (client->pdu) - rpc_pdu_free(client->pdu); + if (client->pdu) + rpc_pdu_free(client->pdu); + if (client->ReceivePool) Queue_Free(client->ReceivePool); + if (client->ReceiveQueue) Queue_Free(client->ReceiveQueue); + if (client->ClientCallList) ArrayList_Free(client->ClientCallList); + if (client->StopEvent) CloseHandle(client->StopEvent); + if (client->PduSentEvent) CloseHandle(client->PduSentEvent); + if (client->Thread) CloseHandle(client->Thread); - free(client); - } - + free(client); return 0; } diff --git a/libfreerdp/core/gateway/rts.c b/libfreerdp/core/gateway/rts.c index 42ce2ad4e..d57a4240d 100644 --- a/libfreerdp/core/gateway/rts.c +++ b/libfreerdp/core/gateway/rts.c @@ -93,25 +93,25 @@ BOOL rts_connect(rdpRpc* rpc) if (!rpc_ntlm_http_out_connect(rpc)) { - fprintf(stderr, "rpc_out_connect_http error!\n"); + fprintf(stderr, "%s: rpc_out_connect_http error!\n", __FUNCTION__); return FALSE; } if (rts_send_CONN_A1_pdu(rpc) != 0) { - fprintf(stderr, "rpc_send_CONN_A1_pdu error!\n"); + fprintf(stderr, "%s: rpc_send_CONN_A1_pdu error!\n", __FUNCTION__); return FALSE; } if (!rpc_ntlm_http_in_connect(rpc)) { - fprintf(stderr, "rpc_in_connect_http error!\n"); + fprintf(stderr, "%s: rpc_in_connect_http error!\n", __FUNCTION__); return FALSE; } - if (rts_send_CONN_B1_pdu(rpc) != 0) + if (rts_send_CONN_B1_pdu(rpc) < 0) { - fprintf(stderr, "rpc_send_CONN_B1_pdu error!\n"); + fprintf(stderr, "%s: rpc_send_CONN_B1_pdu error!\n", __FUNCTION__); return FALSE; } @@ -147,10 +147,15 @@ BOOL rts_connect(rdpRpc* rpc) */ http_response = http_response_recv(rpc->TlsOut); + if (!http_response) + { + fprintf(stderr, "%s: unable to retrieve OUT Channel Response!\n", __FUNCTION__); + return FALSE; + } if (http_response->StatusCode != HTTP_STATUS_OK) { - fprintf(stderr, "rts_connect error! Status Code: %d\n", http_response->StatusCode); + fprintf(stderr, "%s: error! Status Code: %d\n", __FUNCTION__, http_response->StatusCode); http_response_print(http_response); http_response_free(http_response); @@ -170,6 +175,14 @@ BOOL rts_connect(rdpRpc* rpc) return FALSE; } + if (http_response->bodyLen) + { + /* inject bytes we have read in the body as a received packet for the RPC client */ + rpc->client->RecvFrag = rpc_client_fragment_pool_take(rpc); + Stream_EnsureCapacity(rpc->client->RecvFrag, http_response->bodyLen); + CopyMemory(rpc->client->RecvFrag, http_response->BodyContent, http_response->bodyLen); + } + //http_response_print(http_response); http_response_free(http_response); @@ -195,7 +208,6 @@ BOOL rts_connect(rdpRpc* rpc) rpc_client_start(rpc); pdu = rpc_recv_dequeue_pdu(rpc); - if (!pdu) return FALSE; @@ -203,7 +215,7 @@ BOOL rts_connect(rdpRpc* rpc) if (!rts_match_pdu_signature(rpc, &RTS_PDU_CONN_A3_SIGNATURE, rts)) { - fprintf(stderr, "Unexpected RTS PDU: Expected CONN/A3\n"); + fprintf(stderr, "%s: unexpected RTS PDU: Expected CONN/A3\n", __FUNCTION__); return FALSE; } @@ -236,7 +248,6 @@ BOOL rts_connect(rdpRpc* rpc) */ pdu = rpc_recv_dequeue_pdu(rpc); - if (!pdu) return FALSE; @@ -244,7 +255,7 @@ BOOL rts_connect(rdpRpc* rpc) if (!rts_match_pdu_signature(rpc, &RTS_PDU_CONN_C2_SIGNATURE, rts)) { - fprintf(stderr, "Unexpected RTS PDU: Expected CONN/C2\n"); + fprintf(stderr, "%s: unexpected RTS PDU: Expected CONN/C2\n", __FUNCTION__); return FALSE; } @@ -261,7 +272,7 @@ BOOL rts_connect(rdpRpc* rpc) return TRUE; } -#if defined WITH_DEBUG_RTS && 0 +#ifdef WITH_DEBUG_RTS static const char* const RTS_CMD_STRINGS[] = { @@ -317,6 +328,7 @@ static const char* const RTS_CMD_STRINGS[] = void rts_pdu_header_init(rpcconn_rts_hdr_t* header) { + ZeroMemory(header, sizeof(*header)); header->rpc_vers = 5; header->rpc_vers_minor = 0; header->ptype = PTYPE_RTS; @@ -681,6 +693,8 @@ int rts_send_CONN_A1_pdu(rdpRpc* rpc) ReceiveWindowSize = rpc->VirtualConnection->DefaultOutChannel->ReceiveWindow; buffer = (BYTE*) malloc(header.frag_length); + if (!buffer) + return -1; CopyMemory(buffer, ((BYTE*) &header), 20); /* RTS Header (20 bytes) */ rts_version_command_write(&buffer[20]); /* Version (8 bytes) */ @@ -718,6 +732,7 @@ int rts_send_CONN_B1_pdu(rdpRpc* rpc) BYTE* INChannelCookie; BYTE* AssociationGroupId; BYTE* VirtualConnectionCookie; + int status; rts_pdu_header_init(&header); header.frag_length = 104; @@ -734,6 +749,8 @@ int rts_send_CONN_B1_pdu(rdpRpc* rpc) AssociationGroupId = (BYTE*) &(rpc->VirtualConnection->AssociationGroupId); buffer = (BYTE*) malloc(header.frag_length); + if (!buffer) + return -1; CopyMemory(buffer, ((BYTE*) &header), 20); /* RTS Header (20 bytes) */ rts_version_command_write(&buffer[20]); /* Version (8 bytes) */ @@ -745,11 +762,11 @@ int rts_send_CONN_B1_pdu(rdpRpc* rpc) length = header.frag_length; - rpc_in_write(rpc, buffer, length); + status = rpc_in_write(rpc, buffer, length); free(buffer); - return 0; + return status; } /* CONN/C Sequence */ @@ -795,12 +812,15 @@ int rts_send_keep_alive_pdu(rdpRpc* rpc) DEBUG_RPC("Sending Keep-Alive RTS PDU"); buffer = (BYTE*) malloc(header.frag_length); + if (!buffer) + return -1; CopyMemory(buffer, ((BYTE*) &header), 20); /* RTS Header (20 bytes) */ rts_client_keepalive_command_write(&buffer[20], rpc->CurrentKeepAliveInterval); /* ClientKeepAlive (8 bytes) */ length = header.frag_length; - rpc_in_write(rpc, buffer, length); + if (rpc_in_write(rpc, buffer, length) < 0) + return -1; free(buffer); return length; @@ -830,6 +850,8 @@ int rts_send_flow_control_ack_pdu(rdpRpc* rpc) rpc->VirtualConnection->DefaultOutChannel->AvailableWindowAdvertised; buffer = (BYTE*) malloc(header.frag_length); + if (!buffer) + return -1; CopyMemory(buffer, ((BYTE*) &header), 20); /* RTS Header (20 bytes) */ rts_destination_command_write(&buffer[20], FDOutProxy); /* Destination Command (8 bytes) */ @@ -839,7 +861,8 @@ int rts_send_flow_control_ack_pdu(rdpRpc* rpc) length = header.frag_length; - rpc_in_write(rpc, buffer, length); + if (rpc_in_write(rpc, buffer, length) < 0) + return -1; free(buffer); return 0; @@ -923,12 +946,15 @@ int rts_send_ping_pdu(rdpRpc* rpc) DEBUG_RPC("Sending Ping RTS PDU"); buffer = (BYTE*) malloc(header.frag_length); + if (!buffer) + return -1; CopyMemory(buffer, ((BYTE*) &header), 20); /* RTS Header (20 bytes) */ length = header.frag_length; - rpc_in_write(rpc, buffer, length); + if (rpc_in_write(rpc, buffer, length) < 0) + return -1; free(buffer); return length; @@ -1020,22 +1046,18 @@ int rts_recv_out_of_sequence_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length) rts_extract_pdu_signature(rpc, &signature, rts); SignatureId = rts_identify_pdu_signature(rpc, &signature, NULL); - if (SignatureId == RTS_PDU_FLOW_CONTROL_ACK) + switch (SignatureId) { - return rts_recv_flow_control_ack_pdu(rpc, buffer, length); - } - else if (SignatureId == RTS_PDU_FLOW_CONTROL_ACK_WITH_DESTINATION) - { - return rts_recv_flow_control_ack_with_destination_pdu(rpc, buffer, length); - } - else if (SignatureId == RTS_PDU_PING) - { - rts_send_ping_pdu(rpc); - } - else - { - fprintf(stderr, "Unimplemented signature id: 0x%08X\n", SignatureId); - rts_print_pdu_signature(rpc, &signature); + case RTS_PDU_FLOW_CONTROL_ACK: + return rts_recv_flow_control_ack_pdu(rpc, buffer, length); + case RTS_PDU_FLOW_CONTROL_ACK_WITH_DESTINATION: + return rts_recv_flow_control_ack_with_destination_pdu(rpc, buffer, length); + case RTS_PDU_PING: + return rts_send_ping_pdu(rpc); + default: + fprintf(stderr, "%s: unimplemented signature id: 0x%08X\n", __FUNCTION__, SignatureId); + rts_print_pdu_signature(rpc, &signature); + break; } return 0; diff --git a/libfreerdp/core/gateway/rts_signature.c b/libfreerdp/core/gateway/rts_signature.c index 34598fe71..47242ca63 100644 --- a/libfreerdp/core/gateway/rts_signature.c +++ b/libfreerdp/core/gateway/rts_signature.c @@ -234,7 +234,6 @@ BOOL rts_match_pdu_signature(rdpRpc* rpc, RtsPduSignature* signature, rpcconn_rt return FALSE; status = rts_command_length(rpc, CommandType, &buffer[offset], length); - if (status < 0) return FALSE; @@ -272,7 +271,6 @@ int rts_extract_pdu_signature(rdpRpc* rpc, RtsPduSignature* signature, rpcconn_r signature->CommandTypes[i] = CommandType; status = rts_command_length(rpc, CommandType, &buffer[offset], length); - if (status < 0) return FALSE; @@ -294,22 +292,22 @@ UINT32 rts_identify_pdu_signature(rdpRpc* rpc, RtsPduSignature* signature, RTS_P { pSignature = RTS_PDU_SIGNATURE_TABLE[i].Signature; - if (signature->Flags == pSignature->Flags) + if (signature->Flags != pSignature->Flags) + continue; + + if (signature->NumberOfCommands != pSignature->NumberOfCommands) + continue; + + for (j = 0; j < signature->NumberOfCommands; j++) { - if (signature->NumberOfCommands == pSignature->NumberOfCommands) - { - for (j = 0; j < signature->NumberOfCommands; j++) - { - if (signature->CommandTypes[j] != pSignature->CommandTypes[j]) - continue; - } - - if (entry) - *entry = &RTS_PDU_SIGNATURE_TABLE[i]; - - return RTS_PDU_SIGNATURE_TABLE[i].SignatureId; - } + if (signature->CommandTypes[j] != pSignature->CommandTypes[j]) + continue; } + + if (entry) + *entry = &RTS_PDU_SIGNATURE_TABLE[i]; + + return RTS_PDU_SIGNATURE_TABLE[i].SignatureId; } return 0; diff --git a/libfreerdp/core/gateway/tsg.c b/libfreerdp/core/gateway/tsg.c index f130f73ab..5dd68886d 100644 --- a/libfreerdp/core/gateway/tsg.c +++ b/libfreerdp/core/gateway/tsg.c @@ -33,9 +33,9 @@ #include #include "rpc_client.h" - #include "tsg.h" + /** * RPC Functions: http://msdn.microsoft.com/en-us/library/windows/desktop/aa378623/ * Remote Procedure Call: http://msdn.microsoft.com/en-us/library/windows/desktop/aa378651/ @@ -96,7 +96,9 @@ DWORD TsProxySendToServer(handle_t IDL_handle, byte pRpcMessage[], UINT32 count, } length = 28 + totalDataBytes; - buffer = (BYTE*) malloc(length); + buffer = (BYTE*) calloc(1, length); + if (!buffer) + return -1; s = Stream_New(buffer, length); @@ -228,8 +230,9 @@ BOOL TsProxyCreateTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu) if (!(pdu->Flags & RPC_PDU_FLAG_STUB)) buffer = &buffer[24]; - packet = (PTSG_PACKET) malloc(sizeof(TSG_PACKET)); - ZeroMemory(packet, sizeof(TSG_PACKET)); + packet = (PTSG_PACKET) calloc(1, sizeof(TSG_PACKET)); + if (!packet) + return FALSE; offset = 4; // Skip Packet Pointer packet->packetId = *((UINT32*) &buffer[offset]); /* PacketId */ @@ -237,8 +240,9 @@ BOOL TsProxyCreateTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu) if ((packet->packetId == TSG_PACKET_TYPE_CAPS_RESPONSE) && (SwitchValue == TSG_PACKET_TYPE_CAPS_RESPONSE)) { - packetCapsResponse = (PTSG_PACKET_CAPS_RESPONSE) malloc(sizeof(TSG_PACKET_CAPS_RESPONSE)); - ZeroMemory(packetCapsResponse, sizeof(TSG_PACKET_CAPS_RESPONSE)); + packetCapsResponse = (PTSG_PACKET_CAPS_RESPONSE) calloc(1, sizeof(TSG_PACKET_CAPS_RESPONSE)); + if (!packetCapsResponse) // TODO: correct cleanup + return FALSE; packet->tsgPacket.packetCapsResponse = packetCapsResponse; /* PacketQuarResponsePtr (4 bytes) */ @@ -258,8 +262,7 @@ BOOL TsProxyCreateTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu) IsMessagePresent = *((UINT32*) &buffer[offset]); offset += 4; MessageSwitchValue = *((UINT32*) &buffer[offset]); - DEBUG_TSG("IsMessagePresent %d MessageSwitchValue %d", - IsMessagePresent, MessageSwitchValue); + DEBUG_TSG("IsMessagePresent %d MessageSwitchValue %d", IsMessagePresent, MessageSwitchValue); offset += 4; } @@ -289,8 +292,9 @@ BOOL TsProxyCreateTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu) offset += 4; } - versionCaps = (PTSG_PACKET_VERSIONCAPS) malloc(sizeof(TSG_PACKET_VERSIONCAPS)); - ZeroMemory(versionCaps, sizeof(TSG_PACKET_VERSIONCAPS)); + versionCaps = (PTSG_PACKET_VERSIONCAPS) calloc(1, sizeof(TSG_PACKET_VERSIONCAPS)); + if (!versionCaps) // TODO: correct cleanup + return FALSE; packetCapsResponse->pktQuarEncResponse.versionCaps = versionCaps; versionCaps->tsgHeader.ComponentId = *((UINT16*) &buffer[offset]); /* ComponentId */ @@ -317,8 +321,10 @@ BOOL TsProxyCreateTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu) /* 4-byte alignment */ rpc_offset_align(&offset, 4); - tsgCaps = (PTSG_PACKET_CAPABILITIES) malloc(sizeof(TSG_PACKET_CAPABILITIES)); - ZeroMemory(tsgCaps, sizeof(TSG_PACKET_CAPABILITIES)); + tsgCaps = (PTSG_PACKET_CAPABILITIES) calloc(1, sizeof(TSG_PACKET_CAPABILITIES)); + if (!tsgCaps) + return FALSE; + versionCaps->tsgCaps = tsgCaps; offset += 4; /* MaxCount (4 bytes) */ @@ -406,8 +412,9 @@ BOOL TsProxyCreateTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu) } else if ((packet->packetId == TSG_PACKET_TYPE_QUARENC_RESPONSE) && (SwitchValue == TSG_PACKET_TYPE_QUARENC_RESPONSE)) { - packetQuarEncResponse = (PTSG_PACKET_QUARENC_RESPONSE) malloc(sizeof(TSG_PACKET_QUARENC_RESPONSE)); - ZeroMemory(packetQuarEncResponse, sizeof(TSG_PACKET_QUARENC_RESPONSE)); + packetQuarEncResponse = (PTSG_PACKET_QUARENC_RESPONSE) calloc(1, sizeof(TSG_PACKET_QUARENC_RESPONSE)); + if (!packetQuarEncResponse) // TODO: handle cleanup + return FALSE; packet->tsgPacket.packetQuarEncResponse = packetQuarEncResponse; /* PacketQuarResponsePtr (4 bytes) */ @@ -443,8 +450,9 @@ BOOL TsProxyCreateTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu) offset += 4; } - versionCaps = (PTSG_PACKET_VERSIONCAPS) malloc(sizeof(TSG_PACKET_VERSIONCAPS)); - ZeroMemory(versionCaps, sizeof(TSG_PACKET_VERSIONCAPS)); + versionCaps = (PTSG_PACKET_VERSIONCAPS) calloc(1, sizeof(TSG_PACKET_VERSIONCAPS)); + if (!versionCaps) // TODO: handle cleanup + return FALSE; packetQuarEncResponse->versionCaps = versionCaps; versionCaps->tsgHeader.ComponentId = *((UINT16*) &buffer[offset]); /* ComponentId */ @@ -779,8 +787,9 @@ BOOL TsProxyMakeTunnelCallReadResponse(rdpTsg* tsg, RPC_PDU* pdu) if (!(pdu->Flags & RPC_PDU_FLAG_STUB)) buffer = &buffer[24]; - packet = (PTSG_PACKET) malloc(sizeof(TSG_PACKET)); - ZeroMemory(packet, sizeof(TSG_PACKET)); + packet = (PTSG_PACKET) calloc(1, sizeof(TSG_PACKET)); + if (!packet) + return FALSE; offset = 4; packet->packetId = *((UINT32*) &buffer[offset]); /* PacketId */ @@ -923,6 +932,8 @@ BOOL TsProxyCreateChannelWriteRequest(rdpTsg* tsg, PTUNNEL_CONTEXT_HANDLE_NOSERI length = 60 + (count * 2); buffer = (BYTE*) malloc(length); + if (!buffer) + return FALSE; /* TunnelContext */ handle = (CONTEXT_HANDLE*) tunnelContext; @@ -1526,48 +1537,53 @@ int tsg_read(rdpTsg* tsg, BYTE* data, UINT32 length) return CopyLength; } - else + + + tsg->pdu = rpc_recv_peek_pdu(rpc); + if (!tsg->pdu) { - tsg->pdu = rpc_recv_peek_pdu(rpc); + if (!tsg->rpc->client->SynchronousReceive) + return 0; - if (!tsg->pdu) - { - if (tsg->rpc->client->SynchronousReceive) - return tsg_read(tsg, data, length); - else - return 0; - } - - tsg->PendingPdu = TRUE; - tsg->BytesAvailable = Stream_Length(tsg->pdu->s); - tsg->BytesRead = 0; - - CopyLength = (length < tsg->BytesAvailable) ? length : tsg->BytesAvailable; - - CopyMemory(data, &tsg->pdu->s->buffer[tsg->BytesRead], CopyLength); - tsg->BytesAvailable -= CopyLength; - tsg->BytesRead += CopyLength; - - if (tsg->BytesAvailable < 1) - { - tsg->PendingPdu = FALSE; - rpc_recv_dequeue_pdu(rpc); - rpc_client_receive_pool_return(rpc, tsg->pdu); - } - - return CopyLength; + // weird !!!! + return tsg_read(tsg, data, length); } + + tsg->PendingPdu = TRUE; + tsg->BytesAvailable = Stream_Length(tsg->pdu->s); + tsg->BytesRead = 0; + + CopyLength = (length < tsg->BytesAvailable) ? length : tsg->BytesAvailable; + + CopyMemory(data, &tsg->pdu->s->buffer[tsg->BytesRead], CopyLength); + tsg->BytesAvailable -= CopyLength; + tsg->BytesRead += CopyLength; + + if (tsg->BytesAvailable < 1) + { + tsg->PendingPdu = FALSE; + rpc_recv_dequeue_pdu(rpc); + rpc_client_receive_pool_return(rpc, tsg->pdu); + } + + return CopyLength; + } int tsg_write(rdpTsg* tsg, BYTE* data, UINT32 length) { + int status; + if (tsg->rpc->transport->layer == TRANSPORT_LAYER_CLOSED) { - fprintf(stderr, "tsg_write error: connection lost\n"); + fprintf(stderr, "%s: error, connection lost\n", __FUNCTION__); return -1; } - return TsProxySendToServer((handle_t) tsg, data, 1, &length); + status = TsProxySendToServer((handle_t) tsg, data, 1, &length); + if (status < 0) + return -1; + return length; } BOOL tsg_set_blocking_mode(rdpTsg* tsg, BOOL blocking) @@ -1584,18 +1600,21 @@ rdpTsg* tsg_new(rdpTransport* transport) { rdpTsg* tsg; - tsg = (rdpTsg*) malloc(sizeof(rdpTsg)); - ZeroMemory(tsg, sizeof(rdpTsg)); - - if (tsg != NULL) - { - tsg->transport = transport; - tsg->settings = transport->settings; - tsg->rpc = rpc_new(tsg->transport); - tsg->PendingPdu = FALSE; - } + tsg = (rdpTsg*) calloc(1, sizeof(rdpTsg)); + if (!tsg) + return NULL; + tsg->transport = transport; + tsg->settings = transport->settings; + tsg->rpc = rpc_new(tsg->transport); + if (!tsg->rpc) + goto out_free; + tsg->PendingPdu = FALSE; return tsg; + +out_free: + free(tsg); + return NULL; } void tsg_free(rdpTsg* tsg) diff --git a/libfreerdp/core/license.c b/libfreerdp/core/license.c index dc96014a9..4d3a53a0f 100644 --- a/libfreerdp/core/license.c +++ b/libfreerdp/core/license.c @@ -241,7 +241,7 @@ int license_recv(rdpLicense* license, wStream* s) if (!rdp_read_header(license->rdp, s, &length, &channelId)) { - fprintf(stderr, "Incorrect RDP header.\n"); + fprintf(stderr, "%s: Incorrect RDP header.\n", __FUNCTION__); return -1; } @@ -252,7 +252,7 @@ int license_recv(rdpLicense* license, wStream* s) { if (!rdp_decrypt(license->rdp, s, length - 4, securityFlags)) { - fprintf(stderr, "rdp_decrypt failed\n"); + fprintf(stderr, "%s: rdp_decrypt failed\n", __FUNCTION__); return -1; } } @@ -268,7 +268,7 @@ int license_recv(rdpLicense* license, wStream* s) if (status < 0) { - fprintf(stderr, "Unexpected license packet.\n"); + fprintf(stderr, "%s: unexpected license packet.\n", __FUNCTION__); return status; } @@ -308,7 +308,7 @@ int license_recv(rdpLicense* license, wStream* s) break; default: - fprintf(stderr, "invalid bMsgType:%d\n", bMsgType); + fprintf(stderr, "%s: invalid bMsgType:%d\n", __FUNCTION__, bMsgType); return FALSE; } diff --git a/libfreerdp/core/mcs.c b/libfreerdp/core/mcs.c index 31620b8c8..c47660f19 100644 --- a/libfreerdp/core/mcs.c +++ b/libfreerdp/core/mcs.c @@ -186,7 +186,7 @@ static const char* const mcs_result_enumerated[] = int mcs_initialize_client_channels(rdpMcs* mcs, rdpSettings* settings) { - int index; + UINT32 index; mcs->channelCount = settings->ChannelCount; if (mcs->channelCount > mcs->channelMaxCount) @@ -1056,26 +1056,29 @@ rdpMcs* mcs_new(rdpTransport* transport) { rdpMcs* mcs; - mcs = (rdpMcs*) malloc(sizeof(rdpMcs)); + mcs = (rdpMcs *)calloc(1, sizeof(rdpMcs)); + if (!mcs) + return NULL; - if (mcs) - { - ZeroMemory(mcs, sizeof(rdpMcs)); + mcs->transport = transport; + mcs->settings = transport->settings; - mcs->transport = transport; - mcs->settings = transport->settings; + mcs_init_domain_parameters(&mcs->targetParameters, 34, 2, 0, 0xFFFF); + mcs_init_domain_parameters(&mcs->minimumParameters, 1, 1, 1, 0x420); + mcs_init_domain_parameters(&mcs->maximumParameters, 0xFFFF, 0xFC17, 0xFFFF, 0xFFFF); + mcs_init_domain_parameters(&mcs->domainParameters, 0, 0, 0, 0xFFFF); - mcs_init_domain_parameters(&mcs->targetParameters, 34, 2, 0, 0xFFFF); - mcs_init_domain_parameters(&mcs->minimumParameters, 1, 1, 1, 0x420); - mcs_init_domain_parameters(&mcs->maximumParameters, 0xFFFF, 0xFC17, 0xFFFF, 0xFFFF); - mcs_init_domain_parameters(&mcs->domainParameters, 0, 0, 0, 0xFFFF); - - mcs->channelCount = 0; - mcs->channelMaxCount = CHANNEL_MAX_COUNT; - mcs->channels = (rdpMcsChannel*) calloc(mcs->channelMaxCount, sizeof(rdpMcsChannel)); - } + mcs->channelCount = 0; + mcs->channelMaxCount = CHANNEL_MAX_COUNT; + mcs->channels = (rdpMcsChannel *)calloc(mcs->channelMaxCount, sizeof(rdpMcsChannel)); + if (!mcs->channels) + goto out_free; return mcs; + +out_free: + free(mcs); + return NULL; } /** diff --git a/libfreerdp/core/peer.c b/libfreerdp/core/peer.c index e1662d335..bc7431f47 100644 --- a/libfreerdp/core/peer.c +++ b/libfreerdp/core/peer.c @@ -52,13 +52,13 @@ static BOOL freerdp_peer_initialize(freerdp_peer* client) fprintf(stderr, "%s: inavlid RDP key file %s\n", __FUNCTION__, settings->RdpKeyFile); return FALSE; } + if (settings->RdpServerRsaKey->ModulusLength > 256) { fprintf(stderr, "%s: Key sizes > 2048 are currently not supported for RDP security.\n", __FUNCTION__); fprintf(stderr, "%s: Set a different key file than %s\n", __FUNCTION__, settings->RdpKeyFile); exit(1); } - } return TRUE; @@ -77,12 +77,13 @@ static HANDLE freerdp_peer_get_event_handle(freerdp_peer* client) return client->context->rdp->transport->TcpIn->event; } -static BOOL freerdp_peer_check_fds(freerdp_peer* client) + +static BOOL freerdp_peer_check_fds(freerdp_peer* peer) { int status; rdpRdp* rdp; - rdp = client->context->rdp; + rdp = peer->context->rdp; status = rdp_check_fds(rdp); @@ -413,6 +414,19 @@ static int freerdp_peer_send_channel_data(freerdp_peer* client, UINT16 channelId return rdp_send_channel_data(client->context->rdp, channelId, data, size); } +static BOOL freerdp_peer_is_write_blocked(freerdp_peer* peer) +{ + return tranport_is_write_blocked(peer->context->rdp->transport); +} + +static int freerdp_peer_drain_output_buffer(freerdp_peer* peer) +{ + + rdpTransport *transport = peer->context->rdp->transport; + + return tranport_drain_output_buffer(transport); +} + void freerdp_peer_context_new(freerdp_peer* client) { rdpRdp* rdp; @@ -445,6 +459,9 @@ void freerdp_peer_context_new(freerdp_peer* client) rdp->transport->ReceiveExtra = client; transport_set_blocking_mode(rdp->transport, FALSE); + client->IsWriteBlocked = freerdp_peer_is_write_blocked; + client->DrainOutputBuffer = freerdp_peer_drain_output_buffer; + IFCALL(client->ContextNew, client, client->context); } @@ -473,6 +490,8 @@ freerdp_peer* freerdp_peer_new(int sockfd) client->Close = freerdp_peer_close; client->Disconnect = freerdp_peer_disconnect; client->SendChannelData = freerdp_peer_send_channel_data; + client->IsWriteBlocked = freerdp_peer_is_write_blocked; + client->DrainOutputBuffer = freerdp_peer_drain_output_buffer; } return client; @@ -480,10 +499,10 @@ freerdp_peer* freerdp_peer_new(int sockfd) void freerdp_peer_free(freerdp_peer* client) { - if (client) - { - rdp_free(client->context->rdp); - free(client->context); - free(client); - } + if (!client) + return; + + rdp_free(client->context->rdp); + free(client->context); + free(client); } diff --git a/libfreerdp/core/server.c b/libfreerdp/core/server.c index 59657d51b..6242b3c2c 100644 --- a/libfreerdp/core/server.c +++ b/libfreerdp/core/server.c @@ -358,7 +358,7 @@ static void WTSProcessChannelData(rdpPeerChannel* channel, UINT16 channelId, BYT static int WTSReceiveChannelData(freerdp_peer* client, UINT16 channelId, BYTE* data, int size, int flags, int totalSize) { - int i; + UINT32 i; BOOL status = FALSE; rdpPeerChannel* channel; rdpMcs* mcs = client->context->rdp->mcs; @@ -846,8 +846,8 @@ BOOL WINAPI FreeRDP_WTSWaitSystemEvent(HANDLE hServer, DWORD EventMask, DWORD* p HANDLE WINAPI FreeRDP_WTSVirtualChannelOpen(HANDLE hServer, DWORD SessionId, LPSTR pVirtualName) { - int index; int length; + UINT32 index; rdpMcs* mcs; BOOL joined = FALSE; freerdp_peer* client; @@ -910,7 +910,7 @@ HANDLE WINAPI FreeRDP_WTSVirtualChannelOpen(HANDLE hServer, DWORD SessionId, LPS HANDLE WINAPI FreeRDP_WTSVirtualChannelOpenEx(DWORD SessionId, LPSTR pVirtualName, DWORD flags) { - int index; + UINT32 index; wStream* s; rdpMcs* mcs; BOOL joined = FALSE; diff --git a/libfreerdp/core/settings.c b/libfreerdp/core/settings.c index 6538ec7cf..3ae24f03c 100644 --- a/libfreerdp/core/settings.c +++ b/libfreerdp/core/settings.c @@ -209,6 +209,7 @@ rdpSettings* freerdp_settings_new(DWORD flags) ZeroMemory(settings, sizeof(rdpSettings)); settings->ServerMode = (flags & FREERDP_SETTINGS_SERVER_MODE) ? TRUE : FALSE; + settings->WaitForOutputBufferFlush = TRUE; settings->DesktopWidth = 1024; settings->DesktopHeight = 768; @@ -235,6 +236,7 @@ rdpSettings* freerdp_settings_new(DWORD flags) settings->SaltedChecksum = TRUE; settings->ServerPort = 3389; settings->GatewayPort = 443; + settings->GatewayBypassLocal = TRUE; settings->DesktopResize = TRUE; settings->ToggleFullscreen = TRUE; settings->DesktopPosX = 0; @@ -262,6 +264,8 @@ rdpSettings* freerdp_settings_new(DWORD flags) settings->Authentication = TRUE; settings->AuthenticationOnly = FALSE; settings->CredentialsFromStdin = FALSE; + settings->DisableCredentialsDelegation = FALSE; + settings->AuthenticationLevel = 2; settings->ChannelCount = 0; settings->ChannelDefArraySize = 32; @@ -579,6 +583,7 @@ rdpSettings* freerdp_settings_clone(rdpSettings* settings) /* BOOL values */ _settings->ServerMode = settings->ServerMode; /* 16 */ + _settings->WaitForOutputBufferFlush = settings->WaitForOutputBufferFlush; /* 25 */ _settings->NetworkAutoDetect = settings->NetworkAutoDetect; /* 137 */ _settings->SupportAsymetricKeys = settings->SupportAsymetricKeys; /* 138 */ _settings->SupportErrorInfoPdu = settings->SupportErrorInfoPdu; /* 139 */ @@ -628,6 +633,7 @@ rdpSettings* freerdp_settings_clone(rdpSettings* settings) _settings->NegotiateSecurityLayer = settings->NegotiateSecurityLayer; /* 1096 */ _settings->RestrictedAdminModeRequired = settings->RestrictedAdminModeRequired; /* 1097 */ _settings->DisableCredentialsDelegation = settings->DisableCredentialsDelegation; /* 1099 */ + _settings->AuthenticationLevel = settings->AuthenticationLevel; /* 1100 */ _settings->MstscCookieMode = settings->MstscCookieMode; /* 1152 */ _settings->SendPreconnectionPdu = settings->SendPreconnectionPdu; /* 1156 */ _settings->IgnoreCertificate = settings->IgnoreCertificate; /* 1408 */ diff --git a/libfreerdp/core/tcp.c b/libfreerdp/core/tcp.c index 15c417616..f869c4c71 100644 --- a/libfreerdp/core/tcp.c +++ b/libfreerdp/core/tcp.c @@ -66,6 +66,165 @@ #include "tcp.h" +long transport_bio_buffered_callback(BIO* bio, int mode, const char* argp, int argi, long argl, long ret) +{ + return 1; +} + +static int transport_bio_buffered_write(BIO* bio, const char* buf, int num) +{ + int status, ret; + rdpTcp *tcp = (rdpTcp *)bio->ptr; + int nchunks, committedBytes, i; + DataChunk chunks[2]; + + ret = num; + BIO_clear_retry_flags(bio); + tcp->writeBlocked = FALSE; + + /* we directly append extra bytes in the xmit buffer, this could be prevented + * but for now it makes the code more simple. + */ + if (buf && num && !ringbuffer_write(&tcp->xmitBuffer, (const BYTE *)buf, num)) + { + fprintf(stderr, "%s: an error occured when writing(toWrite=%d)\n", __FUNCTION__, num); + return -1; + } + + committedBytes = 0; + nchunks = ringbuffer_peek(&tcp->xmitBuffer, chunks, ringbuffer_used(&tcp->xmitBuffer)); + for (i = 0; i < nchunks; i++) + { + while (chunks[i].size) + { + status = BIO_write(bio->next_bio, chunks[i].data, chunks[i].size); + /*fprintf(stderr, "%s: i=%d/%d size=%d/%d status=%d retry=%d\n", __FUNCTION__, i, nchunks, + chunks[i].size, ringbuffer_used(&tcp->xmitBuffer), status, + BIO_should_retry(bio->next_bio) + );*/ + if (status <= 0) + { + if (BIO_should_retry(bio->next_bio)) + { + tcp->writeBlocked = TRUE; + goto out; /* EWOULDBLOCK */ + } + + /* any other is an error, but we still have to commit written bytes */ + ret = -1; + goto out; + } + + committedBytes += status; + chunks[i].size -= status; + chunks[i].data += status; + } + } + +out: + ringbuffer_commit_read_bytes(&tcp->xmitBuffer, committedBytes); + return ret; +} + +static int transport_bio_buffered_read(BIO* bio, char* buf, int size) +{ + int status; + rdpTcp *tcp = (rdpTcp *)bio->ptr; + + tcp->readBlocked = FALSE; + BIO_clear_retry_flags(bio); + + status = BIO_read(bio->next_bio, buf, size); + /*fprintf(stderr, "%s: size=%d status=%d shouldRetry=%d\n", __FUNCTION__, size, status, BIO_should_retry(bio->next_bio)); */ + + if (status <= 0 && BIO_should_retry(bio->next_bio)) + { + BIO_set_retry_read(bio); + tcp->readBlocked = TRUE; + } + + return status; +} + +static int transport_bio_buffered_puts(BIO* bio, const char* str) +{ + return 1; +} + +static int transport_bio_buffered_gets(BIO* bio, char* str, int size) +{ + return 1; +} + +static long transport_bio_buffered_ctrl(BIO* bio, int cmd, long arg1, void* arg2) +{ + rdpTcp *tcp = (rdpTcp *)bio->ptr; + + switch (cmd) + { + case BIO_CTRL_FLUSH: + return 1; + case BIO_CTRL_WPENDING: + return ringbuffer_used(&tcp->xmitBuffer); + case BIO_CTRL_PENDING: + return 0; + default: + /*fprintf(stderr, "%s: passing to next BIO, bio=%p cmd=%d arg1=%d arg2=%p\n", __FUNCTION__, bio, cmd, arg1, arg2); */ + return BIO_ctrl(bio->next_bio, cmd, arg1, arg2); + } + + return 0; +} + +static int transport_bio_buffered_new(BIO* bio) +{ + bio->init = 1; + bio->num = 0; + bio->ptr = NULL; + bio->flags = 0; + + return 1; +} + +static int transport_bio_buffered_free(BIO* bio) +{ + return 1; +} + + +static BIO_METHOD transport_bio_buffered_socket_methods = +{ + BIO_TYPE_BUFFERED, + "BufferedSocket", + transport_bio_buffered_write, + transport_bio_buffered_read, + transport_bio_buffered_puts, + transport_bio_buffered_gets, + transport_bio_buffered_ctrl, + transport_bio_buffered_new, + transport_bio_buffered_free, + NULL, +}; + +BIO_METHOD* BIO_s_buffered_socket(void) +{ + return &transport_bio_buffered_socket_methods; +} + +BOOL transport_bio_buffered_drain(BIO *bio) +{ + rdpTcp *tcp = (rdpTcp *)bio->ptr; + int status; + + if (!ringbuffer_used(&tcp->xmitBuffer)) + return 1; + + status = transport_bio_buffered_write(bio, NULL, 0); + return status >= 0; +} + + + void tcp_get_ip_address(rdpTcp* tcp) { BYTE* ip; @@ -136,62 +295,65 @@ BOOL tcp_connect(rdpTcp* tcp, const char* hostname, int port) if (hostname[0] == '/') { tcp->sockfd = freerdp_uds_connect(hostname); - if (tcp->sockfd < 0) return FALSE; + + tcp->socketBio = BIO_new_fd(tcp->sockfd, 1); + if (!tcp->socketBio) + return FALSE; } else { - tcp->sockfd = freerdp_tcp_connect(hostname, port); - - if (tcp->sockfd < 0) + tcp->socketBio = BIO_new(BIO_s_connect()); + if (!tcp->socketBio) return FALSE; - SetEventFileDescriptor(tcp->event, tcp->sockfd); + if (BIO_set_conn_hostname(tcp->socketBio, hostname) < 0 || BIO_set_conn_int_port(tcp->socketBio, &port) < 0) + return FALSE; - tcp_get_ip_address(tcp); - tcp_get_mac_address(tcp); + if (BIO_do_connect(tcp->socketBio) <= 0) + return FALSE; - option_value = 1; - option_len = sizeof(option_value); - setsockopt(tcp->sockfd, IPPROTO_TCP, TCP_NODELAY, (void*) &option_value, option_len); - - /* receive buffer must be a least 32 K */ - if (getsockopt(tcp->sockfd, SOL_SOCKET, SO_RCVBUF, (void*) &option_value, &option_len) == 0) - { - if (option_value < (1024 * 32)) - { - option_value = 1024 * 32; - option_len = sizeof(option_value); - setsockopt(tcp->sockfd, SOL_SOCKET, SO_RCVBUF, (void*) &option_value, option_len); - } - } - - tcp_set_keep_alive_mode(tcp); + tcp->sockfd = BIO_get_fd(tcp->socketBio, NULL); } + SetEventFileDescriptor(tcp->event, tcp->sockfd); + + tcp_get_ip_address(tcp); + tcp_get_mac_address(tcp); + + option_value = 1; + option_len = sizeof(option_value); + if (setsockopt(tcp->sockfd, IPPROTO_TCP, TCP_NODELAY, (void*) &option_value, option_len) < 0) + fprintf(stderr, "%s: unable to set TCP_NODELAY\n", __FUNCTION__); + + /* receive buffer must be a least 32 K */ + if (getsockopt(tcp->sockfd, SOL_SOCKET, SO_RCVBUF, (void*) &option_value, &option_len) == 0) + { + if (option_value < (1024 * 32)) + { + option_value = 1024 * 32; + option_len = sizeof(option_value); + if (setsockopt(tcp->sockfd, SOL_SOCKET, SO_RCVBUF, (void*) &option_value, option_len) < 0) + { + fprintf(stderr, "%s: unable to set receive buffer len\n", __FUNCTION__); + return FALSE; + } + } + } + + if (!tcp_set_keep_alive_mode(tcp)) + return FALSE; + + tcp->bufferedBio = BIO_new(BIO_s_buffered_socket()); + if (!tcp->bufferedBio) + return FALSE; + tcp->bufferedBio->ptr = tcp; + + tcp->bufferedBio = BIO_push(tcp->bufferedBio, tcp->socketBio); return TRUE; } -int tcp_read(rdpTcp* tcp, BYTE* data, int length) -{ - return freerdp_tcp_read(tcp->sockfd, data, length); -} - -int tcp_write(rdpTcp* tcp, BYTE* data, int length) -{ - return freerdp_tcp_write(tcp->sockfd, data, length); -} - -int tcp_wait_read(rdpTcp* tcp) -{ - return freerdp_tcp_wait_read(tcp->sockfd); -} - -int tcp_wait_write(rdpTcp* tcp) -{ - return freerdp_tcp_wait_write(tcp->sockfd); -} BOOL tcp_disconnect(rdpTcp* tcp) { @@ -209,7 +371,7 @@ BOOL tcp_set_blocking_mode(rdpTcp* tcp, BOOL blocking) if (flags == -1) { - fprintf(stderr, "tcp_set_blocking_mode: fcntl failed.\n"); + fprintf(stderr, "%s: fcntl failed, %s.\n", __FUNCTION__, strerror(errno)); return FALSE; } @@ -297,6 +459,31 @@ int tcp_attach(rdpTcp* tcp, int sockfd) { tcp->sockfd = sockfd; SetEventFileDescriptor(tcp->event, tcp->sockfd); + + ringbuffer_commit_read_bytes(&tcp->xmitBuffer, ringbuffer_used(&tcp->xmitBuffer)); + + if (tcp->socketBio) + { + if (BIO_set_fd(tcp->socketBio, sockfd, 1) < 0) + return -1; + } + else + { + tcp->socketBio = BIO_new_socket(sockfd, 1); + if (!tcp->socketBio) + return -1; + } + + if (!tcp->bufferedBio) + { + tcp->bufferedBio = BIO_new(BIO_s_buffered_socket()); + if (!tcp->bufferedBio) + return FALSE; + tcp->bufferedBio->ptr = tcp; + + tcp->bufferedBio = BIO_push(tcp->bufferedBio, tcp->socketBio); + } + return 0; } @@ -316,25 +503,38 @@ rdpTcp* tcp_new(rdpSettings* settings) { rdpTcp* tcp; - tcp = (rdpTcp*) malloc(sizeof(rdpTcp)); + tcp = (rdpTcp *)calloc(1, sizeof(rdpTcp)); + if (!tcp) + return NULL; - if (tcp) - { - ZeroMemory(tcp, sizeof(rdpTcp)); + if (!ringbuffer_init(&tcp->xmitBuffer, 0x10000)) + goto out_free; - tcp->sockfd = -1; - tcp->settings = settings; - tcp->event = CreateFileDescriptorEvent(NULL, FALSE, FALSE, tcp->sockfd); - } + tcp->sockfd = -1; + tcp->settings = settings; + +#ifndef _WIN32 + tcp->event = CreateFileDescriptorEvent(NULL, FALSE, FALSE, tcp->sockfd); + if (!tcp->event || tcp->event == INVALID_HANDLE_VALUE) + goto out_ringbuffer; +#endif return tcp; +#ifndef _WIN32 +out_ringbuffer: + ringbuffer_destroy(&tcp->xmitBuffer); +#endif +out_free: + free(tcp); + return NULL; } void tcp_free(rdpTcp* tcp) { - if (tcp) - { - CloseHandle(tcp->event); - free(tcp); - } + if (!tcp) + return; + + ringbuffer_destroy(&tcp->xmitBuffer); + CloseHandle(tcp->event); + free(tcp); } diff --git a/libfreerdp/core/tcp.h b/libfreerdp/core/tcp.h index b43fbaf1c..a8b3153b9 100644 --- a/libfreerdp/core/tcp.h +++ b/libfreerdp/core/tcp.h @@ -31,10 +31,15 @@ #include #include +#include +#include + #ifndef MSG_NOSIGNAL #define MSG_NOSIGNAL 0 #endif +#define BIO_TYPE_BUFFERED 66 + typedef struct rdp_tcp rdpTcp; struct rdp_tcp @@ -46,6 +51,12 @@ struct rdp_tcp #ifdef _WIN32 WSAEVENT wsa_event; #endif + BIO *socketBio; + BIO *bufferedBio; + RingBuffer xmitBuffer; + BOOL writeBlocked; + BOOL readBlocked; + HANDLE event; }; diff --git a/libfreerdp/core/transport.c b/libfreerdp/core/transport.c index c194c292c..bb455a927 100644 --- a/libfreerdp/core/transport.c +++ b/libfreerdp/core/transport.c @@ -33,7 +33,9 @@ #include #include +#include +#include #include #include #include @@ -41,6 +43,12 @@ #ifndef _WIN32 #include #include +#include +#include +#endif + +#ifdef HAVE_VALGRIND_MEMCHECK_H +#include #endif #include "tpkt.h" @@ -48,6 +56,7 @@ #include "transport.h" #include "rdp.h" + #define BUFFER_SIZE 16384 static void* transport_client_thread(void* arg); @@ -69,6 +78,7 @@ void transport_attach(rdpTransport* transport, int sockfd) tcp_attach(transport->TcpIn, sockfd); transport->SplitInputOutput = FALSE; transport->TcpOut = transport->TcpIn; + transport->frontBio = transport->TcpIn->bufferedBio; } void transport_stop(rdpTransport* transport) @@ -98,18 +108,9 @@ BOOL transport_disconnect(rdpTransport* transport) transport_stop(transport); - if (transport->layer == TRANSPORT_LAYER_TLS) - status &= tls_disconnect(transport->TlsIn); - - if ((transport->layer == TRANSPORT_LAYER_TSG) || (transport->layer == TRANSPORT_LAYER_TSG_TLS)) - { - status &= tsg_disconnect(transport->tsg); - } - else - { - status &= tcp_disconnect(transport->TcpIn); - } + BIO_free_all(transport->frontBio); + transport->frontBio = 0; return status; } @@ -131,16 +132,16 @@ static int transport_bio_tsg_write(BIO* bio, const char* buf, int num) rdpTsg* tsg; tsg = (rdpTsg*) bio->ptr; - status = tsg_write(tsg, (BYTE*) buf, num); BIO_clear_retry_flags(bio); + status = tsg_write(tsg, (BYTE*) buf, num); + if (status > 0) + return status; if (status == 0) - { BIO_set_retry_write(bio); - } - return status < 0 ? 0 : num; + return -1; } static int transport_bio_tsg_read(BIO* bio, char* buf, int size) @@ -222,8 +223,13 @@ BIO_METHOD* BIO_s_tsg(void) return &transport_bio_tsg_methods; } + + BOOL transport_connect_tls(rdpTransport* transport) { + rdpSettings *settings = transport->settings; + rdpTls *targetTls; + BIO *targetBio; int tls_status; freerdp* instance; rdpContext* context; @@ -234,61 +240,33 @@ BOOL transport_connect_tls(rdpTransport* transport) if (transport->layer == TRANSPORT_LAYER_TSG) { transport->TsgTls = tls_new(transport->settings); - - transport->TsgTls->methods = BIO_s_tsg(); - transport->TsgTls->tsg = (void*) transport->tsg; - transport->layer = TRANSPORT_LAYER_TSG_TLS; - transport->TsgTls->hostname = transport->settings->ServerHostname; - transport->TsgTls->port = transport->settings->ServerPort; + targetTls = transport->TsgTls; + targetBio = transport->frontBio; + } + else + { + if (!transport->TlsIn) + transport->TlsIn = tls_new(settings); - if (transport->TsgTls->port == 0) - transport->TsgTls->port = 3389; + if (!transport->TlsOut) + transport->TlsOut = transport->TlsIn; - tls_status = tls_connect(transport->TsgTls); + targetTls = transport->TlsIn; + targetBio = transport->TcpIn->bufferedBio; - if (tls_status < 1) - { - if (tls_status < 0) - { - if (!connectErrorCode) - connectErrorCode = TLSCONNECTERROR; - - if (!freerdp_get_last_error(context)) - freerdp_set_last_error(context, FREERDP_ERROR_TLS_CONNECT_FAILED); - } - else - { - if (!freerdp_get_last_error(context)) - freerdp_set_last_error(context, FREERDP_ERROR_CONNECT_CANCELLED); - } - - tls_free(transport->TsgTls); - transport->TsgTls = NULL; - - return FALSE; - } - - return TRUE; + transport->layer = TRANSPORT_LAYER_TLS; } - if (!transport->TlsIn) - transport->TlsIn = tls_new(transport->settings); - if (!transport->TlsOut) - transport->TlsOut = transport->TlsIn; + targetTls->hostname = settings->ServerHostname; + targetTls->port = settings->ServerPort; - transport->layer = TRANSPORT_LAYER_TLS; - transport->TlsIn->sockfd = transport->TcpIn->sockfd; + if (targetTls->port == 0) + targetTls->port = 3389; - transport->TlsIn->hostname = transport->settings->ServerHostname; - transport->TlsIn->port = transport->settings->ServerPort; - - if (transport->TlsIn->port == 0) - transport->TlsIn->port = 3389; - - tls_status = tls_connect(transport->TlsIn); + tls_status = tls_connect(targetTls, targetBio); if (tls_status < 1) { @@ -306,13 +284,13 @@ BOOL transport_connect_tls(rdpTransport* transport) freerdp_set_last_error(context, FREERDP_ERROR_CONNECT_CANCELLED); } - tls_free(transport->TlsIn); - - if (transport->TlsIn == transport->TlsOut) - transport->TlsIn = transport->TlsOut = NULL; - else - transport->TlsIn = NULL; + return FALSE; + } + transport->frontBio = targetTls->bio; + if (!transport->frontBio) + { + fprintf(stderr, "%s: unable to prepend a filtering TLS bio", __FUNCTION__); return FALSE; } @@ -323,6 +301,7 @@ BOOL transport_connect_nla(rdpTransport* transport) { freerdp* instance; rdpSettings* settings; + rdpCredssp *credSsp; settings = transport->settings; instance = (freerdp*) settings->instance; @@ -338,16 +317,22 @@ BOOL transport_connect_nla(rdpTransport* transport) if (!transport->credssp) { transport->credssp = credssp_new(instance, transport, settings); + if (!transport->credssp) + return FALSE; + transport_set_nla_mode(transport, TRUE); if (settings->AuthenticationServiceClass) { transport->credssp->ServicePrincipalName = credssp_make_spn(settings->AuthenticationServiceClass, settings->ServerHostname); + if (!transport->credssp->ServicePrincipalName) + return FALSE; } } - if (credssp_authenticate(transport->credssp) < 0) + credSsp = transport->credssp; + if (credssp_authenticate(credSsp) < 0) { if (!connectErrorCode) connectErrorCode = AUTHENTICATIONERROR; @@ -361,14 +346,14 @@ BOOL transport_connect_nla(rdpTransport* transport) "If credentials are valid, the NTLMSSP implementation may be to blame.\n"); transport_set_nla_mode(transport, FALSE); - credssp_free(transport->credssp); + credssp_free(credSsp); transport->credssp = NULL; return FALSE; } transport_set_nla_mode(transport, FALSE); - credssp_free(transport->credssp); + credssp_free(credSsp); transport->credssp = NULL; return TRUE; @@ -380,38 +365,41 @@ BOOL transport_tsg_connect(rdpTransport* transport, const char* hostname, UINT16 int tls_status; freerdp* instance; rdpContext* context; + rdpSettings *settings = transport->settings; instance = (freerdp*) transport->settings->instance; context = instance->context; tsg = tsg_new(transport); + if (!tsg) + return FALSE; tsg->transport = transport; transport->tsg = tsg; transport->SplitInputOutput = TRUE; if (!transport->TlsIn) - transport->TlsIn = tls_new(transport->settings); - - transport->TlsIn->sockfd = transport->TcpIn->sockfd; - transport->TlsIn->hostname = transport->settings->GatewayHostname; - transport->TlsIn->port = transport->settings->GatewayPort; - - if (transport->TlsIn->port == 0) - transport->TlsIn->port = 443; - + { + transport->TlsIn = tls_new(settings); + if (!transport->TlsIn) + return FALSE; + } if (!transport->TlsOut) - transport->TlsOut = tls_new(transport->settings); + { + transport->TlsOut = tls_new(settings); + if (!transport->TlsOut) + return FALSE; + } - transport->TlsOut->sockfd = transport->TcpOut->sockfd; - transport->TlsOut->hostname = transport->settings->GatewayHostname; - transport->TlsOut->port = transport->settings->GatewayPort; + /* put a decent default value for gateway port */ + if (!settings->GatewayPort) + settings->GatewayPort = 443; - if (transport->TlsOut->port == 0) - transport->TlsOut->port = 443; + transport->TlsIn->hostname = transport->TlsOut->hostname = settings->GatewayHostname; + transport->TlsIn->port = transport->TlsOut->port = settings->GatewayPort; - tls_status = tls_connect(transport->TlsIn); + tls_status = tls_connect(transport->TlsIn, transport->TcpIn->bufferedBio); if (tls_status < 1) { if (tls_status < 0) @@ -428,8 +416,7 @@ BOOL transport_tsg_connect(rdpTransport* transport, const char* hostname, UINT16 return FALSE; } - tls_status = tls_connect(transport->TlsOut); - + tls_status = tls_connect(transport->TlsOut, transport->TcpOut->bufferedBio); if (tls_status < 1) { if (tls_status < 0) @@ -449,6 +436,8 @@ BOOL transport_tsg_connect(rdpTransport* transport, const char* hostname, UINT16 if (!tsg_connect(tsg, hostname, port)) return FALSE; + transport->frontBio = BIO_new(BIO_s_tsg()); + transport->frontBio->ptr = tsg; return TRUE; } @@ -462,15 +451,20 @@ BOOL transport_connect(rdpTransport* transport, const char* hostname, UINT16 por if (transport->GatewayEnabled) { transport->layer = TRANSPORT_LAYER_TSG; + transport->SplitInputOutput = TRUE; transport->TcpOut = tcp_new(settings); - status = tcp_connect(transport->TcpIn, settings->GatewayHostname, settings->GatewayPort); + if (!tcp_connect(transport->TcpIn, settings->GatewayHostname, settings->GatewayPort) || + !tcp_set_blocking_mode(transport->TcpIn, FALSE)) + return FALSE; - if (status) - status = tcp_connect(transport->TcpOut, settings->GatewayHostname, settings->GatewayPort); + if (!tcp_connect(transport->TcpOut, settings->GatewayHostname, settings->GatewayPort) || + !tcp_set_blocking_mode(transport->TcpOut, FALSE)) + return FALSE; - if (status) - status = transport_tsg_connect(transport, hostname, port); + if (!transport_tsg_connect(transport, hostname, port)) + return FALSE; + status = TRUE; } else { @@ -478,6 +472,7 @@ BOOL transport_connect(rdpTransport* transport, const char* hostname, UINT16 por transport->SplitInputOutput = FALSE; transport->TcpOut = transport->TcpIn; + transport->frontBio = transport->TcpIn->bufferedBio; } if (status) @@ -510,11 +505,11 @@ BOOL transport_accept_tls(rdpTransport* transport) transport->TlsOut = transport->TlsIn; transport->layer = TRANSPORT_LAYER_TLS; - transport->TlsIn->sockfd = transport->TcpIn->sockfd; - if (!tls_accept(transport->TlsIn, transport->settings->CertificateFile, transport->settings->PrivateKeyFile)) + if (!tls_accept(transport->TlsIn, transport->TcpIn->bufferedBio, transport->settings->CertificateFile, transport->settings->PrivateKeyFile)) return FALSE; + transport->frontBio = transport->TlsIn->bio; return TRUE; } @@ -533,10 +528,10 @@ BOOL transport_accept_nla(rdpTransport* transport) transport->TlsOut = transport->TlsIn; transport->layer = TRANSPORT_LAYER_TLS; - transport->TlsIn->sockfd = transport->TcpIn->sockfd; - if (!tls_accept(transport->TlsIn, transport->settings->CertificateFile, transport->settings->PrivateKeyFile)) + if (!tls_accept(transport->TlsIn, transport->TcpIn->bufferedBio, settings->CertificateFile, settings->PrivateKeyFile)) return FALSE; + transport->frontBio = transport->TlsIn->bio; /* Network Level Authentication */ @@ -630,56 +625,131 @@ UINT32 nla_header_length(wStream* s) return length; } +static int transport_wait_for_read(rdpTransport* transport) +{ + struct timeval tv; + fd_set rset, wset; + fd_set *rsetPtr = NULL, *wsetPtr = NULL; + rdpTcp *tcpIn; + + tcpIn = transport->TcpIn; + if (tcpIn->readBlocked) + { + rsetPtr = &rset; + FD_ZERO(rsetPtr); + FD_SET(tcpIn->sockfd, rsetPtr); + } + else if (tcpIn->writeBlocked) + { + wsetPtr = &wset; + FD_ZERO(wsetPtr); + FD_SET(tcpIn->sockfd, wsetPtr); + } + + if (!wsetPtr && !rsetPtr) + { + USleep(1000); + return 0; + } + + tv.tv_sec = 0; + tv.tv_usec = 1000; + + return select(tcpIn->sockfd + 1, rsetPtr, wsetPtr, NULL, &tv); +} + + +static int transport_wait_for_write(rdpTransport* transport) +{ + struct timeval tv; + fd_set rset, wset; + fd_set *rsetPtr = NULL, *wsetPtr = NULL; + rdpTcp *tcpOut; + + tcpOut = transport->SplitInputOutput ? transport->TcpOut : transport->TcpIn; + if (tcpOut->writeBlocked) + { + wsetPtr = &wset; + FD_ZERO(wsetPtr); + FD_SET(tcpOut->sockfd, wsetPtr); + } + else if (tcpOut->readBlocked) + { + rsetPtr = &rset; + FD_ZERO(rsetPtr); + FD_SET(tcpOut->sockfd, rsetPtr); + } + + if (!wsetPtr && !rsetPtr) + { + USleep(1000); + return 0; + } + + tv.tv_sec = 0; + tv.tv_usec = 1000; + + return select(tcpOut->sockfd + 1, rsetPtr, wsetPtr, NULL, &tv); +} + + int transport_read_layer(rdpTransport* transport, BYTE* data, int bytes) { int read = 0; int status = -1; + while (read < bytes) { - if (transport->layer == TRANSPORT_LAYER_TLS) - status = tls_read(transport->TlsIn, data + read, bytes - read); - else if (transport->layer == TRANSPORT_LAYER_TCP) - status = tcp_read(transport->TcpIn, data + read, bytes - read); - else if (transport->layer == TRANSPORT_LAYER_TSG) - status = tsg_read(transport->tsg, data + read, bytes - read); - else if (transport->layer == TRANSPORT_LAYER_TSG_TLS) { - status = tls_read(transport->TsgTls, data + read, bytes - read); + status = BIO_read(transport->frontBio, data + read, bytes - read); + + if (!status) + { + transport->layer = TRANSPORT_LAYER_CLOSED; + return -1; } - /* blocking means that we can't continue until this is read */ - - if (!transport->blocking) - return status; - if (status < 0) { - /* A read error indicates that the peer has dropped the connection */ - transport->layer = TRANSPORT_LAYER_CLOSED; - return status; + if (!BIO_should_retry(transport->frontBio)) + { + /* something unexpected happened, let's close */ + transport->layer = TRANSPORT_LAYER_CLOSED; + return -1; + } + + /* non blocking will survive a partial read */ + if (!transport->blocking) + return read; + + /* blocking means that we can't continue until we have read the number of + * requested bytes */ + if (transport_wait_for_read(transport) < 0) + { + fprintf(stderr, "%s: error when selecting for read\n", __FUNCTION__); + return -1; + } + continue; } +#ifdef HAVE_VALGRIND_MEMCHECK_H + VALGRIND_MAKE_MEM_DEFINED(data + read, bytes - read); +#endif + read += status; - - if (status == 0) - { - /* - * instead of sleeping, we should wait timeout on the - * socket but this only happens on initial connection - */ - USleep(transport->SleepInterval); - } } return read; } + + int transport_read(rdpTransport* transport, wStream* s) { int status; int position; int pduLength; - BYTE header[4]; + BYTE *header; int transport_status; position = 0; @@ -710,7 +780,7 @@ int transport_read(rdpTransport* transport, wStream* s) position += status; } - CopyMemory(header, Stream_Buffer(s), 4); /* peek at first 4 bytes */ + header = Stream_Buffer(s); /* if header is present, read exactly one PDU */ @@ -802,6 +872,8 @@ static int transport_read_nonblocking(rdpTransport* transport) return status; } +BOOL transport_bio_buffered_drain(BIO *bio); + int transport_write(rdpTransport* transport, wStream* s) { int length; @@ -827,36 +899,48 @@ int transport_write(rdpTransport* transport, wStream* s) while (length > 0) { - if (transport->layer == TRANSPORT_LAYER_TLS) - status = tls_write(transport->TlsOut, Stream_Pointer(s), length); - else if (transport->layer == TRANSPORT_LAYER_TCP) - status = tcp_write(transport->TcpOut, Stream_Pointer(s), length); - else if (transport->layer == TRANSPORT_LAYER_TSG) - status = tsg_write(transport->tsg, Stream_Pointer(s), length); - else if (transport->layer == TRANSPORT_LAYER_TSG_TLS) - status = tls_write(transport->TsgTls, Stream_Pointer(s), length); + status = BIO_write(transport->frontBio, Stream_Pointer(s), length); - if (status < 0) - break; /* error occurred */ - - if (status == 0) + if (status <= 0) { - /* when sending is blocked in nonblocking mode, the receiving buffer should be checked */ - if (!transport->blocking) - { - /* and in case we do have buffered some data, we set the event so next loop will get it */ - if (transport_read_nonblocking(transport) > 0) - SetEvent(transport->ReceiveEvent); - } + /* the buffered BIO that is at the end of the chain always says OK for writing, + * so a retry means that for any reason we need to read. The most probable + * is a SSL or TSG BIO in the chain. + */ + if (!BIO_should_retry(transport->frontBio)) + return status; - if (transport->layer == TRANSPORT_LAYER_TLS) - tls_wait_write(transport->TlsOut); - else if (transport->layer == TRANSPORT_LAYER_TCP) - tcp_wait_write(transport->TcpOut); - else if (transport->layer == TRANSPORT_LAYER_TSG_TLS) - tls_wait_write(transport->TsgTls); - else - USleep(transport->SleepInterval); + /* non-blocking can live with blocked IOs */ + if (!transport->blocking) + return status; + + if (transport_wait_for_write(transport) < 0) + { + fprintf(stderr, "%s: error when selecting for write\n", __FUNCTION__); + return -1; + } + continue; + } + + if (transport->blocking || transport->settings->WaitForOutputBufferFlush) + { + /* blocking transport, we must ensure the write buffer is really empty */ + rdpTcp *out = transport->TcpOut; + + while (out->writeBlocked) + { + if (transport_wait_for_write(transport) < 0) + { + fprintf(stderr, "%s: error when selecting for write\n", __FUNCTION__); + return -1; + } + + if (!transport_bio_buffered_drain(out->bufferedBio)) + { + fprintf(stderr, "%s: error when draining outputBuffer\n", __FUNCTION__); + return -1; + } + } } length -= status; @@ -945,6 +1029,38 @@ void transport_get_read_handles(rdpTransport* transport, HANDLE* events, DWORD* } } +BOOL tranport_is_write_blocked(rdpTransport* transport) +{ + if (transport->TcpIn->writeBlocked) + return TRUE; + + return transport->SplitInputOutput && + transport->TcpOut && + transport->TcpOut->writeBlocked; +} + +int tranport_drain_output_buffer(rdpTransport* transport) +{ + BOOL ret = FALSE; + + /* First try to send some accumulated bytes in the send buffer */ + if (transport->TcpIn->writeBlocked) + { + if (!transport_bio_buffered_drain(transport->TcpIn->bufferedBio)) + return -1; + ret |= transport->TcpIn->writeBlocked; + } + + if (transport->SplitInputOutput && transport->TcpOut && transport->TcpOut->writeBlocked) + { + if (!transport_bio_buffered_drain(transport->TcpOut->bufferedBio)) + return -1; + ret |= transport->TcpOut->writeBlocked; + } + + return ret; +} + int transport_check_fds(rdpTransport* transport) { int pos; @@ -1079,15 +1195,14 @@ int transport_check_fds(rdpTransport* transport) recv_status = transport->ReceiveCallback(transport, received, transport->ReceiveExtra); - Stream_Release(received); - - if (recv_status < 0) - return -1; - if (recv_status == 1) { return 1; /* session redirection */ } + Stream_Release(received); + + if (recv_status < 0) + return -1; } return 0; @@ -1198,80 +1313,107 @@ rdpTransport* transport_new(rdpSettings* settings) { rdpTransport* transport; - transport = (rdpTransport*) malloc(sizeof(rdpTransport)); + transport = (rdpTransport *)calloc(1, sizeof(rdpTransport)); + if (!transport) + return NULL; - if (transport) - { - ZeroMemory(transport, sizeof(rdpTransport)); + WLog_Init(); + transport->log = WLog_Get("com.freerdp.core.transport"); + if (!transport->log) + goto out_free; - WLog_Init(); - transport->log = WLog_Get("com.freerdp.core.transport"); + transport->TcpIn = tcp_new(settings); + if (!transport->TcpIn) + goto out_free; - transport->TcpIn = tcp_new(settings); + transport->settings = settings; - transport->settings = settings; + /* a small 0.1ms delay when transport is blocking. */ + transport->SleepInterval = 100; - /* a small 0.1ms delay when transport is blocking. */ - transport->SleepInterval = 100; + transport->ReceivePool = StreamPool_New(TRUE, BUFFER_SIZE); + if (!transport->ReceivePool) + goto out_free_tcpin; - transport->ReceivePool = StreamPool_New(TRUE, BUFFER_SIZE); + /* receive buffer for non-blocking read. */ + transport->ReceiveBuffer = StreamPool_Take(transport->ReceivePool, 0); + if (!transport->ReceiveBuffer) + goto out_free_receivepool; - /* receive buffer for non-blocking read. */ - transport->ReceiveBuffer = StreamPool_Take(transport->ReceivePool, 0); - transport->ReceiveEvent = CreateEvent(NULL, TRUE, FALSE, NULL); + transport->ReceiveEvent = CreateEvent(NULL, TRUE, FALSE, NULL); + if (!transport->ReceiveEvent || transport->ReceiveEvent == INVALID_HANDLE_VALUE) + goto out_free_receivebuffer; - transport->connectedEvent = CreateEvent(NULL, TRUE, FALSE, NULL); + transport->connectedEvent = CreateEvent(NULL, TRUE, FALSE, NULL); + if (!transport->connectedEvent || transport->connectedEvent == INVALID_HANDLE_VALUE) + goto out_free_receiveEvent; - transport->blocking = TRUE; - transport->GatewayEnabled = FALSE; + transport->blocking = TRUE; + transport->GatewayEnabled = FALSE; + transport->layer = TRANSPORT_LAYER_TCP; - InitializeCriticalSectionAndSpinCount(&(transport->ReadLock), 4000); - InitializeCriticalSectionAndSpinCount(&(transport->WriteLock), 4000); - - transport->layer = TRANSPORT_LAYER_TCP; - } + if (!InitializeCriticalSectionAndSpinCount(&(transport->ReadLock), 4000)) + goto out_free_connectedEvent; + if (!InitializeCriticalSectionAndSpinCount(&(transport->WriteLock), 4000)) + goto out_free_readlock; return transport; + +out_free_readlock: + DeleteCriticalSection(&(transport->ReadLock)); +out_free_connectedEvent: + CloseHandle(transport->connectedEvent); +out_free_receiveEvent: + CloseHandle(transport->ReceiveEvent); +out_free_receivebuffer: + StreamPool_Return(transport->ReceivePool, transport->ReceiveBuffer); +out_free_receivepool: + StreamPool_Free(transport->ReceivePool); +out_free_tcpin: + tcp_free(transport->TcpIn); +out_free: + free(transport); + return NULL; } void transport_free(rdpTransport* transport) { - if (transport) - { - transport_stop(transport); + if (!transport) + return; - if (transport->ReceiveBuffer) - Stream_Release(transport->ReceiveBuffer); + transport_stop(transport); - StreamPool_Free(transport->ReceivePool); + if (transport->ReceiveBuffer) + Stream_Release(transport->ReceiveBuffer); - CloseHandle(transport->ReceiveEvent); - CloseHandle(transport->connectedEvent); + StreamPool_Free(transport->ReceivePool); - if (transport->TlsIn) - tls_free(transport->TlsIn); + CloseHandle(transport->ReceiveEvent); + CloseHandle(transport->connectedEvent); - if (transport->TlsOut != transport->TlsIn) - tls_free(transport->TlsOut); + if (transport->TlsIn) + tls_free(transport->TlsIn); - transport->TlsIn = NULL; - transport->TlsOut = NULL; + if (transport->TlsOut != transport->TlsIn) + tls_free(transport->TlsOut); - if (transport->TcpIn) - tcp_free(transport->TcpIn); + transport->TlsIn = NULL; + transport->TlsOut = NULL; - if (transport->TcpOut != transport->TcpIn) - tcp_free(transport->TcpOut); + if (transport->TcpIn) + tcp_free(transport->TcpIn); - transport->TcpIn = NULL; - transport->TcpOut = NULL; + if (transport->TcpOut != transport->TcpIn) + tcp_free(transport->TcpOut); - tsg_free(transport->tsg); - transport->tsg = NULL; + transport->TcpIn = NULL; + transport->TcpOut = NULL; - DeleteCriticalSection(&(transport->ReadLock)); - DeleteCriticalSection(&(transport->WriteLock)); + tsg_free(transport->tsg); + transport->tsg = NULL; - free(transport); - } + DeleteCriticalSection(&(transport->ReadLock)); + DeleteCriticalSection(&(transport->WriteLock)); + + free(transport); } diff --git a/libfreerdp/core/transport.h b/libfreerdp/core/transport.h index b8834ce7a..4e9f7e5a4 100644 --- a/libfreerdp/core/transport.h +++ b/libfreerdp/core/transport.h @@ -49,11 +49,13 @@ typedef struct rdp_transport rdpTransport; #include #include + typedef int (*TransportRecv) (rdpTransport* transport, wStream* stream, void* extra); struct rdp_transport { TRANSPORT_LAYER layer; + BIO *frontBio; rdpTsg* tsg; rdpTcp* TcpIn; rdpTcp* TcpOut; @@ -102,6 +104,8 @@ BOOL transport_set_blocking_mode(rdpTransport* transport, BOOL blocking); void transport_set_gateway_enabled(rdpTransport* transport, BOOL GatewayEnabled); void transport_set_nla_mode(rdpTransport* transport, BOOL NlaMode); void transport_get_read_handles(rdpTransport* transport, HANDLE* events, DWORD* count); +BOOL tranport_is_write_blocked(rdpTransport* transport); +int tranport_drain_output_buffer(rdpTransport* transport); wStream* transport_receive_pool_take(rdpTransport* transport); int transport_receive_pool_return(rdpTransport* transport, wStream* pdu); diff --git a/libfreerdp/core/update.c b/libfreerdp/core/update.c index b322fe753..15c5b9cf5 100644 --- a/libfreerdp/core/update.c +++ b/libfreerdp/core/update.c @@ -544,7 +544,7 @@ static void update_end_paint(rdpContext* context) if (update->numberOrders > 0) { - printf("Sending %d orders\n", update->numberOrders); + fprintf(stderr, "%s: sending %d orders\n", __FUNCTION__, update->numberOrders); fastpath_send_update_pdu(context->rdp->fastpath, FASTPATH_UPDATETYPE_ORDERS, s); } diff --git a/libfreerdp/crypto/test/.gitignore b/libfreerdp/crypto/test/.gitignore new file mode 100644 index 000000000..d425a5a86 --- /dev/null +++ b/libfreerdp/crypto/test/.gitignore @@ -0,0 +1 @@ +TestFreeRDPCrypto.c diff --git a/libfreerdp/crypto/test/TestBase64.c b/libfreerdp/crypto/test/TestBase64.c index 5c32ea422..a50b6b7b0 100644 --- a/libfreerdp/crypto/test/TestBase64.c +++ b/libfreerdp/crypto/test/TestBase64.c @@ -1,24 +1,20 @@ /** - * Copyright © 2014 Thincast Technologies GmbH - * Copyright © 2014 Hardening + * FreeRDP: A Remote Desktop Protocol Implementation * - * Permission to use, copy, modify, distribute, and sell this software and - * its documentation for any purpose is hereby granted without fee, provided - * that the above copyright notice appear in all copies and that both that - * copyright notice and this permission notice appear in supporting - * documentation, and that the name of the copyright holders not be used in - * advertising or publicity pertaining to distribution of the software - * without specific, written prior permission. The copyright holders make - * no representations about the suitability of this software for any - * purpose. It is provided "as is" without express or implied warranty. + * Copyright 2014 Thincast Technologies GmbH + * Copyright 2014 Hardening * - * THE COPYRIGHT HOLDERS DISCLAIM ALL WARRANTIES WITH REGARD TO THIS - * SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS, IN NO EVENT SHALL THE COPYRIGHT HOLDERS BE LIABLE FOR ANY - * SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER - * RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF - * CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN - * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ #include diff --git a/libfreerdp/crypto/tls.c b/libfreerdp/crypto/tls.c index 52c217782..016584fcc 100644 --- a/libfreerdp/crypto/tls.c +++ b/libfreerdp/crypto/tls.c @@ -28,34 +28,35 @@ #include #include +#include #include - -#ifdef HAVE_VALGRIND_MEMCHECK_H -#include -#endif +#include "../core/tcp.h" static CryptoCert tls_get_certificate(rdpTls* tls, BOOL peer) { CryptoCert cert; - X509* server_cert; + X509* remote_cert; if (peer) - server_cert = SSL_get_peer_certificate(tls->ssl); + remote_cert = SSL_get_peer_certificate(tls->ssl); else - server_cert = SSL_get_certificate(tls->ssl); + remote_cert = SSL_get_certificate(tls->ssl); - if (!server_cert) + if (!remote_cert) { - fprintf(stderr, "tls_get_certificate: failed to get the server TLS certificate\n"); - cert = NULL; - } - else - { - cert = malloc(sizeof(*cert)); - cert->px509 = server_cert; + fprintf(stderr, "%s: failed to get the server TLS certificate\n", __FUNCTION__); + return NULL; } + cert = malloc(sizeof(*cert)); + if (!cert) + { + X509_free(remote_cert); + return NULL; + } + + cert->px509 = remote_cert; return cert; } @@ -83,12 +84,14 @@ SecPkgContext_Bindings* tls_get_channel_bindings(X509* cert) PrefixLength = strlen(TLS_SERVER_END_POINT); ChannelBindingTokenLength = PrefixLength + CertificateHashLength; - ContextBindings = (SecPkgContext_Bindings*) malloc(sizeof(SecPkgContext_Bindings)); - ZeroMemory(ContextBindings, sizeof(SecPkgContext_Bindings)); + ContextBindings = (SecPkgContext_Bindings*) calloc(1, sizeof(SecPkgContext_Bindings)); + if (!ContextBindings) + return NULL; ContextBindings->BindingsLength = sizeof(SEC_CHANNEL_BINDINGS) + ChannelBindingTokenLength; - ChannelBindings = (SEC_CHANNEL_BINDINGS*) malloc(ContextBindings->BindingsLength); - ZeroMemory(ChannelBindings, ContextBindings->BindingsLength); + ChannelBindings = (SEC_CHANNEL_BINDINGS*) calloc(1, ContextBindings->BindingsLength); + if (!ChannelBindings) + goto out_free; ContextBindings->Bindings = ChannelBindings; ChannelBindings->cbApplicationDataLength = ChannelBindingTokenLength; @@ -99,32 +102,121 @@ SecPkgContext_Bindings* tls_get_channel_bindings(X509* cert) CopyMemory(&ChannelBindingToken[PrefixLength], CertificateHash, CertificateHashLength); return ContextBindings; + +out_free: + free(ContextBindings); + return NULL; } -static void tls_ssl_info_callback(const SSL* ssl, int type, int val) + +BOOL tls_prepare(rdpTls* tls, BIO *underlying, const SSL_METHOD *method, int options, BOOL clientMode) { - if (type & SSL_CB_HANDSHAKE_START) - { - - } -} - -int tls_connect(rdpTls* tls) -{ - CryptoCert cert; - long options = 0; - int verify_status; - int connection_status; - - tls->ctx = SSL_CTX_new(TLSv1_client_method()); - + tls->ctx = SSL_CTX_new(method); if (!tls->ctx) { - fprintf(stderr, "SSL_CTX_new failed\n"); + fprintf(stderr, "%s: SSL_CTX_new failed\n", __FUNCTION__); + return FALSE; + } + + SSL_CTX_set_mode(tls->ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER | SSL_MODE_ENABLE_PARTIAL_WRITE); + + SSL_CTX_set_options(tls->ctx, options); + SSL_CTX_set_read_ahead(tls->ctx, 1); + + tls->bio = BIO_new_ssl(tls->ctx, clientMode); + if (BIO_get_ssl(tls->bio, &tls->ssl) < 0) + { + fprintf(stderr, "%s: unable to retrieve the SSL of the connection\n", __FUNCTION__); + return FALSE; + } + + BIO_push(tls->bio, underlying); + return TRUE; +} + +int tls_do_handshake(rdpTls* tls, BOOL clientMode) +{ + CryptoCert cert; + int verify_status, status; + + do + { + struct timeval tv; + fd_set rset; + int fd; + + status = BIO_do_handshake(tls->bio); + if (status == 1) + break; + if (!BIO_should_retry(tls->bio)) + return -1; + + /* we select() only for read even if we should test both read and write + * depending of what have blocked */ + FD_ZERO(&rset); + + fd = BIO_get_fd(tls->bio, NULL); + if (fd < 0) + { + fprintf(stderr, "%s: unable to retrieve BIO fd\n", __FUNCTION__); + return -1; + } + + FD_SET(fd, &rset); + tv.tv_sec = 0; + tv.tv_usec = 10 * 1000; /* 10ms */ + + status = select(fd + 1, &rset, NULL, NULL, &tv); + if (status < 0) + { + fprintf(stderr, "%s: error during select()\n", __FUNCTION__); + return -1; + } + } + while (TRUE); + + if (!clientMode) + return 1; + + cert = tls_get_certificate(tls, clientMode); + if (!cert) + { + fprintf(stderr, "%s: tls_get_certificate failed to return the server certificate.\n", __FUNCTION__); return -1; } - //SSL_CTX_set_mode(tls->ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER | SSL_MODE_ENABLE_PARTIAL_WRITE); + tls->Bindings = tls_get_channel_bindings(cert->px509); + if (!tls->Bindings) + { + fprintf(stderr, "%s: unable to retrieve bindings\n", __FUNCTION__); + return -1; + } + + if (!crypto_cert_get_public_key(cert, &tls->PublicKey, &tls->PublicKeyLength)) + { + fprintf(stderr, "%s: crypto_cert_get_public_key failed to return the server public key.\n", __FUNCTION__); + tls_free_certificate(cert); + return -1; + } + + verify_status = tls_verify_certificate(tls, cert, tls->hostname, tls->port); + + if (verify_status < 1) + { + fprintf(stderr, "%s: certificate not trusted, aborting.\n", __FUNCTION__); + tls_disconnect(tls); + tls_free_certificate(cert); + return 0; + } + + tls_free_certificate(cert); + + return verify_status; +} + +int tls_connect(rdpTls* tls, BIO *underlying) +{ + int options = 0; /** * SSL_OP_NO_COMPRESSION: @@ -138,7 +230,7 @@ int tls_connect(rdpTls* tls) #ifdef SSL_OP_NO_COMPRESSION options |= SSL_OP_NO_COMPRESSION; #endif - + /** * SSL_OP_TLS_BLOCK_PADDING_BUG: * @@ -155,96 +247,19 @@ int tls_connect(rdpTls* tls) */ options |= SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS; - SSL_CTX_set_options(tls->ctx, options); + if (!tls_prepare(tls, underlying, TLSv1_client_method(), options, TRUE)) + return FALSE; - tls->ssl = SSL_new(tls->ctx); - - if (!tls->ssl) - { - fprintf(stderr, "SSL_new failed\n"); - return -1; - } - - if (tls->tsg) - { - tls->bio = BIO_new(tls->methods); - - if (!tls->bio) - { - fprintf(stderr, "BIO_new failed\n"); - return -1; - } - - tls->bio->ptr = tls->tsg; - - SSL_set_bio(tls->ssl, tls->bio, tls->bio); - - SSL_CTX_set_info_callback(tls->ctx, tls_ssl_info_callback); - } - else - { - if (SSL_set_fd(tls->ssl, tls->sockfd) < 1) - { - fprintf(stderr, "SSL_set_fd failed\n"); - return -1; - } - } - - connection_status = SSL_connect(tls->ssl); - - if (connection_status <= 0) - { - if (tls_print_error("SSL_connect", tls->ssl, connection_status)) - { - return -1; - } - } - - cert = tls_get_certificate(tls, TRUE); - - if (!cert) - { - fprintf(stderr, "tls_connect: tls_get_certificate failed to return the server certificate.\n"); - return -1; - } - - tls->Bindings = tls_get_channel_bindings(cert->px509); - - if (!crypto_cert_get_public_key(cert, &tls->PublicKey, &tls->PublicKeyLength)) - { - fprintf(stderr, "tls_connect: crypto_cert_get_public_key failed to return the server public key.\n"); - tls_free_certificate(cert); - return -1; - } - - verify_status = tls_verify_certificate(tls, cert, tls->hostname, tls->port); - - if (verify_status < 1) - { - fprintf(stderr, "tls_connect: certificate not trusted, aborting.\n"); - tls_disconnect(tls); - } - - tls_free_certificate(cert); - - return verify_status; + return tls_do_handshake(tls, TRUE); } -BOOL tls_accept(rdpTls* tls, const char* cert_file, const char* privatekey_file) + + +BOOL tls_accept(rdpTls* tls, BIO *underlying, const char* cert_file, const char* privatekey_file) { - CryptoCert cert; long options = 0; - int connection_status; - tls->ctx = SSL_CTX_new(SSLv23_server_method()); - - if (tls->ctx == NULL) - { - fprintf(stderr, "SSL_CTX_new failed\n"); - return FALSE; - } - - /* + /** * SSL_OP_NO_SSLv2: * * We only want SSLv3 and TLSv1, so disable SSLv2. @@ -281,80 +296,23 @@ BOOL tls_accept(rdpTls* tls, const char* cert_file, const char* privatekey_file) */ options |= SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS; - SSL_CTX_set_options(tls->ctx, options); - - if (SSL_CTX_use_RSAPrivateKey_file(tls->ctx, privatekey_file, SSL_FILETYPE_PEM) <= 0) - { - fprintf(stderr, "SSL_CTX_use_RSAPrivateKey_file failed\n"); - fprintf(stderr, "PrivateKeyFile: %s\n", privatekey_file); + if (!tls_prepare(tls, underlying, SSLv23_server_method(), options, FALSE)) return FALSE; - } - tls->ssl = SSL_new(tls->ctx); - - if (!tls->ssl) + if (SSL_use_RSAPrivateKey_file(tls->ssl, privatekey_file, SSL_FILETYPE_PEM) <= 0) { - fprintf(stderr, "SSL_new failed\n"); + fprintf(stderr, "%s: SSL_CTX_use_RSAPrivateKey_file failed\n", __FUNCTION__); + fprintf(stderr, "PrivateKeyFile: %s\n", privatekey_file); return FALSE; } if (SSL_use_certificate_file(tls->ssl, cert_file, SSL_FILETYPE_PEM) <= 0) { - fprintf(stderr, "SSL_use_certificate_file failed\n"); + fprintf(stderr, "%s: SSL_use_certificate_file failed\n", __FUNCTION__); return FALSE; } - if (SSL_set_fd(tls->ssl, tls->sockfd) < 1) - { - fprintf(stderr, "SSL_set_fd failed\n"); - return FALSE; - } - - while (1) - { - connection_status = SSL_accept(tls->ssl); - - if (connection_status <= 0) - { - switch (SSL_get_error(tls->ssl, connection_status)) - { - case SSL_ERROR_WANT_READ: - case SSL_ERROR_WANT_WRITE: - break; - - default: - if (tls_print_error("SSL_accept", tls->ssl, connection_status)) - return FALSE; - break; - - } - } - else - { - break; - } - } - - cert = tls_get_certificate(tls, FALSE); - - if (!cert) - { - fprintf(stderr, "tls_connect: tls_get_certificate failed to return the server certificate.\n"); - return FALSE; - } - - if (!crypto_cert_get_public_key(cert, &tls->PublicKey, &tls->PublicKeyLength)) - { - fprintf(stderr, "tls_connect: crypto_cert_get_public_key failed to return the server public key.\n"); - tls_free_certificate(cert); - return FALSE; - } - - free(cert); - - fprintf(stderr, "TLS connection accepted\n"); - - return TRUE; + return tls_do_handshake(tls, FALSE) > 0; } BOOL tls_disconnect(rdpTls* tls) @@ -362,256 +320,161 @@ BOOL tls_disconnect(rdpTls* tls) if (!tls) return FALSE; - if (tls->ssl) + if (!tls->ssl) + return TRUE; + + if (tls->alertDescription != TLS_ALERT_DESCRIPTION_CLOSE_NOTIFY) { - if (tls->alertDescription != TLS_ALERT_DESCRIPTION_CLOSE_NOTIFY) - { - /** - * OpenSSL doesn't really expose an API for sending a TLS alert manually. - * - * The following code disables the sending of the default "close notify" - * and then proceeds to force sending a custom TLS alert before shutting down. - * - * Manually sending a TLS alert is necessary in certain cases, - * like when server-side NLA results in an authentication failure. - */ + /** + * OpenSSL doesn't really expose an API for sending a TLS alert manually. + * + * The following code disables the sending of the default "close notify" + * and then proceeds to force sending a custom TLS alert before shutting down. + * + * Manually sending a TLS alert is necessary in certain cases, + * like when server-side NLA results in an authentication failure. + */ - SSL_set_quiet_shutdown(tls->ssl, 1); + SSL_set_quiet_shutdown(tls->ssl, 1); - if ((tls->alertLevel == TLS_ALERT_LEVEL_FATAL) && (tls->ssl->session)) - SSL_CTX_remove_session(tls->ssl->ctx, tls->ssl->session); + if ((tls->alertLevel == TLS_ALERT_LEVEL_FATAL) && (tls->ssl->session)) + SSL_CTX_remove_session(tls->ssl->ctx, tls->ssl->session); - tls->ssl->s3->alert_dispatch = 1; - tls->ssl->s3->send_alert[0] = tls->alertLevel; - tls->ssl->s3->send_alert[1] = tls->alertDescription; + tls->ssl->s3->alert_dispatch = 1; + tls->ssl->s3->send_alert[0] = tls->alertLevel; + tls->ssl->s3->send_alert[1] = tls->alertDescription; - if (tls->ssl->s3->wbuf.left == 0) - tls->ssl->method->ssl_dispatch_alert(tls->ssl); + if (tls->ssl->s3->wbuf.left == 0) + tls->ssl->method->ssl_dispatch_alert(tls->ssl); - SSL_shutdown(tls->ssl); - } - else - { - SSL_shutdown(tls->ssl); - } + SSL_shutdown(tls->ssl); + } + else + { + SSL_shutdown(tls->ssl); } return TRUE; } -int tls_read(rdpTls* tls, BYTE* data, int length) + +BIO *findBufferedBio(BIO *front) { - int error; - int status; + BIO *ret = front; - if (!tls) - return -1; - - if (!tls->ssl) - return -1; - - status = SSL_read(tls->ssl, data, length); - - if (status == 0) + while (ret) { - return -1; /* peer disconnected */ + if (BIO_method_type(ret) == BIO_TYPE_BUFFERED) + return ret; + ret = ret->next_bio; } - if (status <= 0) - { - error = SSL_get_error(tls->ssl, status); - - //fprintf(stderr, "tls_read: length: %d status: %d error: 0x%08X\n", - // length, status, error); - - switch (error) - { - case SSL_ERROR_NONE: - break; - - case SSL_ERROR_WANT_READ: - case SSL_ERROR_WANT_WRITE: - status = 0; - break; - - case SSL_ERROR_SYSCALL: -#ifdef _WIN32 - if (WSAGetLastError() == WSAEWOULDBLOCK) -#else - if ((errno == EAGAIN) || (errno == 0)) -#endif - { - status = 0; - } - else - { - if (tls_print_error("SSL_read", tls->ssl, status)) - { - status = -1; - } - else - { - status = 0; - } - } - break; - - default: - if (tls_print_error("SSL_read", tls->ssl, status)) - { - status = -1; - } - else - { - status = 0; - } - break; - } - } - -#ifdef HAVE_VALGRIND_MEMCHECK_H - VALGRIND_MAKE_MEM_DEFINED(data, status); -#endif - - return status; + return ret; } -int tls_write(rdpTls* tls, BYTE* data, int length) +int tls_write_all(rdpTls* tls, const BYTE* data, int length) { - int error; - int status; + int status, nchunks, commitedBytes; + rdpTcp *tcp; + fd_set rset, wset; + fd_set *rsetPtr, *wsetPtr; + struct timeval tv; + BIO *bio = tls->bio; + DataChunk chunks[2]; - if (!tls) - return -1; - - if (!tls->ssl) - return -1; - - status = SSL_write(tls->ssl, data, length); - - if (status == 0) + BIO *bufferedBio = findBufferedBio(bio); + if (!bufferedBio) { - return -1; /* peer disconnected */ + fprintf(stderr, "%s: error unable to retrieve the bufferedBio in the BIO chain\n", __FUNCTION__); + return -1; } - if (status < 0) - { - error = SSL_get_error(tls->ssl, status); - - //fprintf(stderr, "tls_write: length: %d status: %d error: 0x%08X\n", length, status, error); - - switch (error) - { - case SSL_ERROR_NONE: - break; - - case SSL_ERROR_WANT_READ: - case SSL_ERROR_WANT_WRITE: - status = 0; - break; - - case SSL_ERROR_SYSCALL: - if (errno == EAGAIN) - { - status = 0; - } - else - { - tls_print_error("SSL_write", tls->ssl, status); - status = -1; - } - break; - - default: - tls_print_error("SSL_write", tls->ssl, status); - status = -1; - break; - } - } - - return status; -} - -int tls_write_all(rdpTls* tls, BYTE* data, int length) -{ - int status; - int sent = 0; + tcp = (rdpTcp *)bufferedBio->ptr; do { - status = tls_write(tls, &data[sent], length - sent); - + status = BIO_write(bio, data, length); + /*fprintf(stderr, "%s: BIO_write(len=%d) = %d (retry=%d)\n", __FUNCTION__, length, status, BIO_should_retry(bio));*/ if (status > 0) - sent += status; - else if (status == 0) - tls_wait_write(tls); - - if (sent >= length) break; + + if (!BIO_should_retry(bio)) + return -1; + + /* we try to handle SSL want_read and want_write nicely */ + rsetPtr = wsetPtr = 0; + if (tcp->writeBlocked) + { + wsetPtr = &wset; + FD_ZERO(&wset); + FD_SET(tcp->sockfd, &wset); + } + else if (tcp->readBlocked) + { + rsetPtr = &rset; + FD_ZERO(&rset); + FD_SET(tcp->sockfd, &rset); + } + else + { + fprintf(stderr, "%s: weird we're blocked but the underlying is not read or write blocked !\n", __FUNCTION__); + USleep(10); + continue; + } + + tv.tv_sec = 0; + tv.tv_usec = 100 * 1000; + + status = select(tcp->sockfd + 1, rsetPtr, wsetPtr, NULL, &tv); + if (status < 0) + return -1; } - while (status >= 0); + while (TRUE); - if (status > 0) - return length; - else - return status; -} - -int tls_wait_read(rdpTls* tls) -{ - return freerdp_tcp_wait_read(tls->sockfd); -} - -int tls_wait_write(rdpTls* tls) -{ - return freerdp_tcp_wait_write(tls->sockfd); -} - -static void tls_errors(const char *prefix) -{ - unsigned long error; - - while ((error = ERR_get_error()) != 0) - fprintf(stderr, "%s: %s\n", prefix, ERR_error_string(error, NULL)); -} - -BOOL tls_print_error(char* func, SSL* connection, int value) -{ - switch (SSL_get_error(connection, value)) + /* make sure the output buffer is empty */ + commitedBytes = 0; + while ((nchunks = ringbuffer_peek(&tcp->xmitBuffer, chunks, ringbuffer_used(&tcp->xmitBuffer)))) { - case SSL_ERROR_ZERO_RETURN: - fprintf(stderr, "%s: Server closed TLS connection\n", func); - return TRUE; + int i; - case SSL_ERROR_WANT_READ: - fprintf(stderr, "%s: SSL_ERROR_WANT_READ\n", func); - return FALSE; + for (i = 0; i < nchunks; i++) + { + while (chunks[i].size) + { + status = BIO_write(tcp->socketBio, chunks[i].data, chunks[i].size); + if (status > 0) + { + chunks[i].size -= status; + chunks[i].data += status; + commitedBytes += status; + continue; + } - case SSL_ERROR_WANT_WRITE: - fprintf(stderr, "%s: SSL_ERROR_WANT_WRITE\n", func); - return FALSE; + if (!BIO_should_retry(tcp->socketBio)) + goto out_fail; + FD_ZERO(&rset); + FD_SET(tcp->sockfd, &rset); + tv.tv_sec = 0; + tv.tv_usec = 100 * 1000; - case SSL_ERROR_SYSCALL: -#ifdef _WIN32 - fprintf(stderr, "%s: I/O error: %d\n", func, WSAGetLastError()); -#else - fprintf(stderr, "%s: I/O error: %s (%d)\n", func, strerror(errno), errno); -#endif - tls_errors(func); - return TRUE; + status = select(tcp->sockfd + 1, &rset, NULL, NULL, &tv); + if (status < 0) + goto out_fail; + } - case SSL_ERROR_SSL: - fprintf(stderr, "%s: Failure in SSL library (protocol error?)\n", func); - tls_errors(func); - return TRUE; - - default: - fprintf(stderr, "%s: Unknown error\n", func); - tls_errors(func); - return TRUE; + } } + + ringbuffer_commit_read_bytes(&tcp->xmitBuffer, commitedBytes); + return length; + +out_fail: + ringbuffer_commit_read_bytes(&tcp->xmitBuffer, commitedBytes); + return -1; } + + int tls_set_alert_code(rdpTls* tls, int level, int description) { tls->alertLevel = level; @@ -672,7 +535,7 @@ int tls_verify_certificate(rdpTls* tls, CryptoCert cert, char* hostname, int por if (!bio) { - fprintf(stderr, "tls_verify_certificate: BIO_new() failure\n"); + fprintf(stderr, "%s: BIO_new() failure\n", __FUNCTION__); return -1; } @@ -680,7 +543,7 @@ int tls_verify_certificate(rdpTls* tls, CryptoCert cert, char* hostname, int por if (status < 0) { - fprintf(stderr, "tls_verify_certificate: PEM_write_bio_X509 failure: %d\n", status); + fprintf(stderr, "%s: PEM_write_bio_X509 failure: %d\n", __FUNCTION__, status); return -1; } @@ -692,7 +555,7 @@ int tls_verify_certificate(rdpTls* tls, CryptoCert cert, char* hostname, int por if (status < 0) { - fprintf(stderr, "tls_verify_certificate: failed to read certificate\n"); + fprintf(stderr, "%s: failed to read certificate\n", __FUNCTION__); return -1; } @@ -713,7 +576,7 @@ int tls_verify_certificate(rdpTls* tls, CryptoCert cert, char* hostname, int por if (status < 0) { - fprintf(stderr, "tls_verify_certificate: failed to read certificate\n"); + fprintf(stderr, "%s: failed to read certificate\n", __FUNCTION__); return -1; } @@ -727,8 +590,7 @@ int tls_verify_certificate(rdpTls* tls, CryptoCert cert, char* hostname, int por status = instance->VerifyX509Certificate(instance, pemCert, length, hostname, port, 0); } - fprintf(stderr, "VerifyX509Certificate: (length = %d) status: %d\n%s\n", - length, status, pemCert); + fprintf(stderr, "%s: (length = %d) status: %d\n%s\n", __FUNCTION__, length, status, pemCert); free(pemCert); BIO_free(bio); @@ -932,57 +794,53 @@ rdpTls* tls_new(rdpSettings* settings) { rdpTls* tls; - tls = (rdpTls*) malloc(sizeof(rdpTls)); + tls = (rdpTls *)calloc(1, sizeof(rdpTls)); + if (!tls) + return NULL; - if (tls) - { - ZeroMemory(tls, sizeof(rdpTls)); + SSL_load_error_strings(); + SSL_library_init(); - SSL_load_error_strings(); - SSL_library_init(); - - tls->settings = settings; - tls->certificate_store = certificate_store_new(settings); - - tls->alertLevel = TLS_ALERT_LEVEL_WARNING; - tls->alertDescription = TLS_ALERT_DESCRIPTION_CLOSE_NOTIFY; - } + tls->settings = settings; + tls->certificate_store = certificate_store_new(settings); + if (!tls->certificate_store) + goto out_free; + tls->alertLevel = TLS_ALERT_LEVEL_WARNING; + tls->alertDescription = TLS_ALERT_DESCRIPTION_CLOSE_NOTIFY; return tls; + +out_free: + free(tls); + return NULL; } void tls_free(rdpTls* tls) { - if (tls) + if (!tls) + return; + + if (tls->ctx) { - if (tls->ssl) - { - SSL_free(tls->ssl); - tls->ssl = NULL; - } - - if (tls->ctx) - { - SSL_CTX_free(tls->ctx); - tls->ctx = NULL; - } - - if (tls->PublicKey) - { - free(tls->PublicKey); - tls->PublicKey = NULL; - } - - if (tls->Bindings) - { - free(tls->Bindings->Bindings); - free(tls->Bindings); - tls->Bindings = NULL; - } - - certificate_store_free(tls->certificate_store); - tls->certificate_store = NULL; - - free(tls); + SSL_CTX_free(tls->ctx); + tls->ctx = NULL; } + + if (tls->PublicKey) + { + free(tls->PublicKey); + tls->PublicKey = NULL; + } + + if (tls->Bindings) + { + free(tls->Bindings->Bindings); + free(tls->Bindings); + tls->Bindings = NULL; + } + + certificate_store_free(tls->certificate_store); + tls->certificate_store = NULL; + + free(tls); } diff --git a/libfreerdp/utils/CMakeLists.txt b/libfreerdp/utils/CMakeLists.txt index 716e96384..6e5858672 100644 --- a/libfreerdp/utils/CMakeLists.txt +++ b/libfreerdp/utils/CMakeLists.txt @@ -25,6 +25,7 @@ set(${MODULE_PREFIX}_SRCS pcap.c profiler.c rail.c + ringbuffer.c signal.c stopwatch.c svc_plugin.c @@ -68,3 +69,9 @@ else() endif() set_property(TARGET ${MODULE_NAME} PROPERTY FOLDER "FreeRDP/libfreerdp") + + +if(BUILD_TESTING) + add_subdirectory(test) +endif() + diff --git a/libfreerdp/utils/ringbuffer.c b/libfreerdp/utils/ringbuffer.c new file mode 100644 index 000000000..a1f14ac5b --- /dev/null +++ b/libfreerdp/utils/ringbuffer.c @@ -0,0 +1,251 @@ +/** + * FreeRDP: A Remote Desktop Protocol Implementation + * + * Copyright 2014 Thincast Technologies GmbH + * Copyright 2014 Hardening + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include +#include +#include + + +BOOL ringbuffer_init(RingBuffer *rb, size_t initialSize) +{ + rb->buffer = malloc(initialSize); + if (!rb->buffer) + return FALSE; + + rb->readPtr = rb->writePtr = 0; + rb->initialSize = rb->size = rb->freeSize = initialSize; + return TRUE; +} + + +size_t ringbuffer_used(const RingBuffer *ringbuffer) +{ + return ringbuffer->size - ringbuffer->freeSize; +} + +size_t ringbuffer_capacity(const RingBuffer *ringbuffer) +{ + return ringbuffer->size; +} + +void ringbuffer_destroy(RingBuffer *ringbuffer) +{ + free(ringbuffer->buffer); + ringbuffer->buffer = 0; +} + +static BOOL ringbuffer_realloc(RingBuffer *rb, size_t targetSize) +{ + BYTE *newData; + + if (rb->writePtr == rb->readPtr) + { + /* when no size is used we can realloc() and set the heads at the + * beginning of the buffer + */ + newData = (BYTE *)realloc(rb->buffer, targetSize); + if (!newData) + return FALSE; + rb->readPtr = rb->writePtr = 0; + rb->buffer = newData; + } + else if ((rb->writePtr >= rb->readPtr) && (rb->writePtr < targetSize)) + { + /* we reallocate only if we're in that case, realloc don't touch read + * and write heads + * + * readPtr writePtr + * | | + * v v + * [............|XXXXXXXXXXXXXX|..........] + */ + newData = (BYTE *)realloc(rb->buffer, targetSize); + if (!newData) + return FALSE; + + rb->buffer = newData; + } + else + { + /* in case of malloc the read head is moved at the beginning of the new buffer + * and the write head is set accordingly + */ + newData = (BYTE *)malloc(targetSize); + if (!newData) + return FALSE; + if (rb->readPtr < rb->writePtr) + { + /* readPtr writePtr + * | | + * v v + * [............|XXXXXXXXXXXXXX|..........] + */ + memcpy(newData, rb->buffer + rb->readPtr, ringbuffer_used(rb)); + } + else + { + /* writePtr readPtr + * | | + * v v + * [XXXXXXXXXXXX|..............|XXXXXXXXXX] + */ + BYTE *dst = newData; + memcpy(dst, rb->buffer + rb->readPtr, rb->size - rb->readPtr); + dst += (rb->size - rb->readPtr); + if (rb->writePtr) + memcpy(dst, rb->buffer, rb->writePtr); + } + rb->writePtr = rb->size - rb->freeSize; + rb->readPtr = 0; + free(rb->buffer); + rb->buffer = newData; + } + + rb->freeSize += (targetSize - rb->size); + rb->size = targetSize; + return TRUE; +} + +/** + * + * @param rb + * @param ptr + * @param sz + * @return + */ +BOOL ringbuffer_write(RingBuffer *rb, const BYTE *ptr, size_t sz) +{ + size_t toWrite; + size_t remaining; + + if ((rb->freeSize <= sz) && !ringbuffer_realloc(rb, rb->size + sz)) + return FALSE; + + /* the write could be split in two + * readHead writeHead + * | | + * v v + * [ ################ ] + */ + toWrite = sz; + remaining = sz; + if (rb->size - rb->writePtr < sz) + toWrite = rb->size - rb->writePtr; + + if (toWrite) + { + memcpy(rb->buffer + rb->writePtr, ptr, toWrite); + remaining -= toWrite; + ptr += toWrite; + } + + if (remaining) + memcpy(rb->buffer, ptr, remaining); + + rb->writePtr = (rb->writePtr + sz) % rb->size; + + rb->freeSize -= sz; + return TRUE; +} + + +BYTE *ringbuffer_ensure_linear_write(RingBuffer *rb, size_t sz) +{ + if (rb->freeSize < sz) + { + if (!ringbuffer_realloc(rb, rb->size + sz - rb->freeSize + 32)) + return NULL; + } + + if (rb->writePtr == rb->readPtr) + { + rb->writePtr = rb->readPtr = 0; + } + + if (rb->writePtr + sz < rb->size) + return rb->buffer + rb->writePtr; + + /* + * to add: ....... + * [ XXXXXXXXX ] + * + * result: + * [XXXXXXXXX....... ] + */ + memmove(rb->buffer, rb->buffer + rb->readPtr, rb->writePtr - rb->readPtr); + rb->readPtr = 0; + rb->writePtr = rb->size - rb->freeSize; + return rb->buffer + rb->writePtr; +} + +BOOL ringbuffer_commit_written_bytes(RingBuffer *rb, size_t sz) +{ + if (rb->writePtr + sz > rb->size) + return FALSE; + rb->writePtr = (rb->writePtr + sz) % rb->size; + rb->freeSize -= sz; + return TRUE; +} + +int ringbuffer_peek(const RingBuffer *rb, DataChunk chunks[2], size_t sz) +{ + size_t remaining = sz; + size_t toRead; + int chunkIndex = 0; + int ret = 0; + + if (rb->size - rb->freeSize < sz) + remaining = rb->size - rb->freeSize; + + toRead = remaining; + + if (rb->readPtr + remaining > rb->size) + toRead = rb->size - rb->readPtr; + + if (toRead) + { + chunks[0].data = rb->buffer + rb->readPtr; + chunks[0].size = toRead; + remaining -= toRead; + chunkIndex++; + ret++; + } + + if (remaining) + { + chunks[chunkIndex].data = rb->buffer; + chunks[chunkIndex].size = remaining; + ret++; + } + return ret; +} + +void ringbuffer_commit_read_bytes(RingBuffer *rb, size_t sz) +{ + assert(rb->size - rb->freeSize >= sz); + + rb->readPtr = (rb->readPtr + sz) % rb->size; + rb->freeSize += sz; + + /* when we reach a reasonable free size, we can go back to the original size */ + if ((rb->size != rb->initialSize) && (ringbuffer_used(rb) < rb->initialSize / 2)) + ringbuffer_realloc(rb, rb->initialSize); +} diff --git a/libfreerdp/utils/test/.gitignore b/libfreerdp/utils/test/.gitignore new file mode 100644 index 000000000..0e7faad57 --- /dev/null +++ b/libfreerdp/utils/test/.gitignore @@ -0,0 +1 @@ +TestFreeRDPutils.c diff --git a/libfreerdp/utils/test/CMakeLists.txt b/libfreerdp/utils/test/CMakeLists.txt new file mode 100644 index 000000000..e6ab6134c --- /dev/null +++ b/libfreerdp/utils/test/CMakeLists.txt @@ -0,0 +1,36 @@ + +set(MODULE_NAME "TestFreeRDPUtils") +set(MODULE_PREFIX "TEST_FREERDP_UTILS") + +set(${MODULE_PREFIX}_DRIVER ${MODULE_NAME}.c) + +set(${MODULE_PREFIX}_TESTS + TestRingBuffer.c) + +create_test_sourcelist(${MODULE_PREFIX}_SRCS + ${${MODULE_PREFIX}_DRIVER} + ${${MODULE_PREFIX}_TESTS}) + +add_executable(${MODULE_NAME} ${${MODULE_PREFIX}_SRCS}) + +set_complex_link_libraries(VARIABLE ${MODULE_PREFIX}_LIBS + MONOLITHIC ${MONOLITHIC_BUILD} + MODULE winpr + MODULES winpr-thread winpr-synch winpr-file winpr-utils winpr-crt) + +set_complex_link_libraries(VARIABLE ${MODULE_PREFIX}_LIBS + MONOLITHIC ${MONOLITHIC_BUILD} + MODULE freerdp + MODULES freerdp-utils) + +target_link_libraries(${MODULE_NAME} ${${MODULE_PREFIX}_LIBS}) + +set_target_properties(${MODULE_NAME} PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${TESTING_OUTPUT_DIRECTORY}") + +foreach(test ${${MODULE_PREFIX}_TESTS}) + get_filename_component(TestName ${test} NAME_WE) + add_test(${TestName} ${TESTING_OUTPUT_DIRECTORY}/${MODULE_NAME} ${TestName}) +endforeach() + +set_property(TARGET ${MODULE_NAME} PROPERTY FOLDER "FreeRDP/Test") + diff --git a/libfreerdp/utils/test/TestRingBuffer.c b/libfreerdp/utils/test/TestRingBuffer.c new file mode 100644 index 000000000..1f4e3f504 --- /dev/null +++ b/libfreerdp/utils/test/TestRingBuffer.c @@ -0,0 +1,228 @@ +/** + * FreeRDP: A Remote Desktop Protocol Implementation + * + * Copyright 2014 Thincast Technologies GmbH + * Copyright 2014 Hardening + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include + +BOOL test_overlaps(void) +{ + RingBuffer rb; + DataChunk chunks[2]; + BYTE bytes[200]; + int nchunks, i, j, k, counter = 0; + + for (i = 0; i < sizeof(bytes); i++) + bytes[i] = (BYTE)i; + + ringbuffer_init(&rb, 5); + if (!ringbuffer_write(&rb, bytes, 4)) /* [0123.] */ + goto error; + counter += 4; + ringbuffer_commit_read_bytes(&rb, 2); /* [..23.] */ + + if (!ringbuffer_write(&rb, &bytes[counter], 2)) /* [5.234] */ + goto error; + counter += 2; + + nchunks = ringbuffer_peek(&rb, chunks, 4); + if (nchunks != 2 || chunks[0].size != 3 || chunks[1].size != 1) + goto error; + + for (i = 0, j = 2; i < nchunks; i++) + { + for (k = 0; k < (int) chunks[i].size; k++, j++) + { + if (chunks[i].data[k] != (BYTE)j) + goto error; + } + } + + ringbuffer_commit_read_bytes(&rb, 3); /* [5....] */ + if (ringbuffer_used(&rb) != 1) + goto error; + + if (!ringbuffer_write(&rb, &bytes[counter], 6)) /* [56789ab....] */ + goto error; + counter += 6; + + ringbuffer_commit_read_bytes(&rb, 6); /* [......b....] */ + nchunks = ringbuffer_peek(&rb, chunks, 10); + if (nchunks != 1 || chunks[0].size != 1 || (*chunks[0].data != 0xb)) + goto error; + + if (ringbuffer_capacity(&rb) != 5) + goto error; + + ringbuffer_destroy(&rb); + return TRUE; +error: + ringbuffer_destroy(&rb); + return FALSE; +} + + +int TestRingBuffer(int argc, char* argv[]) +{ + RingBuffer ringBuffer; + int testNo = 0; + BYTE *tmpBuf; + BYTE *rb_ptr; + int i/*, chunkNb, counter*/; + DataChunk chunks[2]; + + if (!ringbuffer_init(&ringBuffer, 10)) + { + fprintf(stderr, "unable to initialize ringbuffer\n"); + return -1; + } + + tmpBuf = (BYTE *)malloc(50); + if (!tmpBuf) + return -1; + + for (i = 0; i < 50; i++) + tmpBuf[i] = (char)i; + + fprintf(stderr, "%d: basic tests...", ++testNo); + if (!ringbuffer_write(&ringBuffer, tmpBuf, 5) || !ringbuffer_write(&ringBuffer, tmpBuf, 5) || + !ringbuffer_write(&ringBuffer, tmpBuf, 5)) + { + fprintf(stderr, "error when writing bytes\n"); + return -1; + } + + if (ringbuffer_used(&ringBuffer) != 15) + { + fprintf(stderr, "invalid used size got %d when i would expect 15\n", ringbuffer_used(&ringBuffer)); + return -1; + } + + if (ringbuffer_peek(&ringBuffer, chunks, 10) != 1 || chunks[0].size != 10) + { + fprintf(stderr, "error when reading bytes\n"); + return -1; + } + ringbuffer_commit_read_bytes(&ringBuffer, chunks[0].size); + + /* check retrieved bytes */ + for (i = 0; i < (int) chunks[0].size; i++) + { + if (chunks[0].data[i] != i % 5) + { + fprintf(stderr, "invalid byte at %d, got %d instead of %d\n", i, chunks[0].data[i], i % 5); + return -1; + } + } + + if (ringbuffer_used(&ringBuffer) != 5) + { + fprintf(stderr, "invalid used size after read got %d when i would expect 5\n", ringbuffer_used(&ringBuffer)); + return -1; + } + + /* write some more bytes to have writePtr < readPtr and data splitted in 2 chunks */ + if (!ringbuffer_write(&ringBuffer, tmpBuf, 6) || + ringbuffer_peek(&ringBuffer, chunks, 11) != 2 || + chunks[0].size != 10 || + chunks[1].size != 1) + { + fprintf(stderr, "invalid read of splitted data\n"); + return -1; + } + + ringbuffer_commit_read_bytes(&ringBuffer, 11); + fprintf(stderr, "ok\n"); + + fprintf(stderr, "%d: peek with nothing to read...", ++testNo); + if (ringbuffer_peek(&ringBuffer, chunks, 10)) + { + fprintf(stderr, "peek returns some chunks\n"); + return -1; + } + fprintf(stderr, "ok\n"); + + fprintf(stderr, "%d: ensure_linear_write / read() shouldn't grow...", ++testNo); + for (i = 0; i < 1000; i++) + { + rb_ptr = ringbuffer_ensure_linear_write(&ringBuffer, 50); + if (!rb_ptr) + { + fprintf(stderr, "ringbuffer_ensure_linear_write() error\n"); + return -1; + } + + memcpy(rb_ptr, tmpBuf, 50); + + if (!ringbuffer_commit_written_bytes(&ringBuffer, 50)) + { + fprintf(stderr, "ringbuffer_commit_written_bytes() error, i=%d\n", i); + return -1; + } + + //ringbuffer_commit_read_bytes(&ringBuffer, 25); + } + + for (i = 0; i < 1000; i++) + ringbuffer_commit_read_bytes(&ringBuffer, 25); + + for (i = 0; i < 1000; i++) + ringbuffer_commit_read_bytes(&ringBuffer, 25); + + + if (ringbuffer_capacity(&ringBuffer) != 10) + { + fprintf(stderr, "not the expected capacity, have %d and expects 10\n", ringbuffer_capacity(&ringBuffer)); + return -1; + } + fprintf(stderr, "ok\n"); + + + fprintf(stderr, "%d: free size is correctly computed...", ++testNo); + for (i = 0; i < 1000; i++) + { + ringbuffer_ensure_linear_write(&ringBuffer, 50); + if (!ringbuffer_commit_written_bytes(&ringBuffer, 50)) + { + fprintf(stderr, "ringbuffer_commit_written_bytes() error, i=%d\n", i); + return -1; + } + } + ringbuffer_commit_read_bytes(&ringBuffer, 50 * 1000); + fprintf(stderr, "ok\n"); + + ringbuffer_destroy(&ringBuffer); + + fprintf(stderr, "%d: specific overlaps test...", ++testNo); + if (!test_overlaps()) + { + fprintf(stderr, "ko\n", i); + return -1; + } + fprintf(stderr, "ok\n"); + + ringbuffer_destroy(&ringBuffer); + free(tmpBuf); + return 0; +} + + + + diff --git a/server/CMakeLists.txt b/server/CMakeLists.txt index e51b2c82a..a361a8b63 100644 --- a/server/CMakeLists.txt +++ b/server/CMakeLists.txt @@ -19,27 +19,29 @@ add_subdirectory(common) -if(WITH_SAMPLE) - add_subdirectory(Sample) -endif() - -if(NOT WIN32) - if(WITH_X11) - add_subdirectory(X11) +if(FREERDP_VENDOR) + if(WITH_SAMPLE) + add_subdirectory(Sample) endif() - if(APPLE AND (NOT IOS)) - add_subdirectory(Mac) + if(NOT WIN32) + if(WITH_X11) + add_subdirectory(X11) + endif() + + if(APPLE AND (NOT IOS)) + add_subdirectory(Mac) + endif() + else() + add_subdirectory(Windows) + endif() + + if(IS_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/FreeRDS") + add_subdirectory("FreeRDS") endif() -else() - add_subdirectory(Windows) endif() -if(IS_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/FreeRDS") - add_subdirectory("FreeRDS") -endif() - -# Pick up other clients +# Pick up other servers set(FILENAME "ModuleOptions.cmake") file(GLOB FILEPATHS RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*/${FILENAME}") diff --git a/winpr/include/winpr/bitstream.h b/winpr/include/winpr/bitstream.h index f782aa23f..b77a1dc97 100644 --- a/winpr/include/winpr/bitstream.h +++ b/winpr/include/winpr/bitstream.h @@ -48,37 +48,37 @@ extern "C" { #define BitStream_Prefetch(_bs) do { \ (_bs->prefetch) = 0; \ - if ((_bs->pointer - _bs->buffer) < (_bs->capacity + 4)) \ + if (((UINT32) (_bs->pointer - _bs->buffer)) < (_bs->capacity + 4)) \ (_bs->prefetch) |= (*(_bs->pointer + 4) << 24); \ - if ((_bs->pointer - _bs->buffer) < (_bs->capacity + 5)) \ + if (((UINT32) (_bs->pointer - _bs->buffer)) < (_bs->capacity + 5)) \ (_bs->prefetch) |= (*(_bs->pointer + 5) << 16); \ - if ((_bs->pointer - _bs->buffer) < (_bs->capacity + 6)) \ + if (((UINT32) (_bs->pointer - _bs->buffer)) < (_bs->capacity + 6)) \ (_bs->prefetch) |= (*(_bs->pointer + 6) << 8); \ - if ((_bs->pointer - _bs->buffer) < (_bs->capacity + 7)) \ + if (((UINT32) (_bs->pointer - _bs->buffer)) < (_bs->capacity + 7)) \ (_bs->prefetch) |= (*(_bs->pointer + 7) << 0); \ } while(0) #define BitStream_Fetch(_bs) do { \ (_bs->accumulator) = 0; \ - if ((_bs->pointer - _bs->buffer) < (_bs->capacity + 0)) \ + if (((UINT32) (_bs->pointer - _bs->buffer)) < (_bs->capacity + 0)) \ (_bs->accumulator) |= (*(_bs->pointer + 0) << 24); \ - if ((_bs->pointer - _bs->buffer) < (_bs->capacity + 1)) \ + if (((UINT32) (_bs->pointer - _bs->buffer)) < (_bs->capacity + 1)) \ (_bs->accumulator) |= (*(_bs->pointer + 1) << 16); \ - if ((_bs->pointer - _bs->buffer) < (_bs->capacity + 2)) \ + if (((UINT32) (_bs->pointer - _bs->buffer)) < (_bs->capacity + 2)) \ (_bs->accumulator) |= (*(_bs->pointer + 2) << 8); \ - if ((_bs->pointer - _bs->buffer) < (_bs->capacity + 3)) \ + if (((UINT32) (_bs->pointer - _bs->buffer)) <(_bs->capacity + 3)) \ (_bs->accumulator) |= (*(_bs->pointer + 3) << 0); \ BitStream_Prefetch(_bs); \ } while(0) #define BitStream_Flush(_bs) do { \ - if ((_bs->pointer - _bs->buffer) < (_bs->capacity + 0)) \ + if (((UINT32) (_bs->pointer - _bs->buffer)) < (_bs->capacity + 0)) \ *(_bs->pointer + 0) = (_bs->accumulator >> 24); \ - if ((_bs->pointer - _bs->buffer) < (_bs->capacity + 1)) \ + if (((UINT32) (_bs->pointer - _bs->buffer)) < (_bs->capacity + 1)) \ *(_bs->pointer + 1) = (_bs->accumulator >> 16); \ - if ((_bs->pointer - _bs->buffer) < (_bs->capacity + 2)) \ + if (((UINT32) (_bs->pointer - _bs->buffer)) < (_bs->capacity + 2)) \ *(_bs->pointer + 2) = (_bs->accumulator >> 8); \ - if ((_bs->pointer - _bs->buffer) < (_bs->capacity + 3)) \ + if (((UINT32) (_bs->pointer - _bs->buffer)) < (_bs->capacity + 3)) \ *(_bs->pointer + 3) = (_bs->accumulator >> 0); \ } while(0) diff --git a/winpr/include/winpr/error.h b/winpr/include/winpr/error.h index 6f59e4884..e84b8f7bc 100644 --- a/winpr/include/winpr/error.h +++ b/winpr/include/winpr/error.h @@ -29,10 +29,6 @@ #else -#ifdef __cplusplus -extern "C" { -#endif - #ifndef NO_ERROR #define NO_ERROR 0 #endif @@ -3003,6 +2999,10 @@ typedef PTOP_LEVEL_EXCEPTION_FILTER LPTOP_LEVEL_EXCEPTION_FILTER; typedef LONG (*PVECTORED_EXCEPTION_HANDLER)(PEXCEPTION_POINTERS ExceptionInfo); +#ifdef __cplusplus +extern "C" { +#endif + WINPR_API UINT GetErrorMode(void); WINPR_API UINT SetErrorMode(UINT uMode); diff --git a/winpr/include/winpr/locale.h b/winpr/include/winpr/locale.h index b63b421bd..d6bcabfe6 100644 --- a/winpr/include/winpr/locale.h +++ b/winpr/include/winpr/locale.h @@ -25,6 +25,9 @@ #ifndef _WIN32 +#include +#include + #define LANG_NEUTRAL 0x00 #define LANG_INVARIANT 0x7f @@ -483,7 +486,17 @@ extern "C" { #endif +DWORD WINAPI FormatMessageA(DWORD dwFlags, LPCVOID lpSource, DWORD dwMessageId, DWORD dwLanguageId, + LPSTR lpBuffer, DWORD nSize, va_list* Arguments); +DWORD WINAPI FormatMessageW(DWORD dwFlags, LPCVOID lpSource, DWORD dwMessageId, DWORD dwLanguageId, + LPWSTR lpBuffer, DWORD nSize, va_list* Arguments); + +#ifdef UNICODE +#define FormatMessage FormatMessageW +#else +#define FormatMessage FormatMessageA +#endif #ifdef __cplusplus } diff --git a/winpr/libwinpr/com/com.c b/winpr/libwinpr/com/com.c index 97332a070..367f12b6f 100644 --- a/winpr/libwinpr/com/com.c +++ b/winpr/libwinpr/com/com.c @@ -22,6 +22,7 @@ #endif #include +#include /** * api-ms-win-core-com-l1-1-0.dll: @@ -110,6 +111,14 @@ #ifndef _WIN32 +HRESULT CoInitializeEx(LPVOID pvReserved, DWORD dwCoInit) +{ + return S_OK; +} +void CoUninitialize(void) +{ + +} #endif diff --git a/winpr/libwinpr/handle/CMakeLists.txt b/winpr/libwinpr/handle/CMakeLists.txt index 3dc9dfb0c..93d112bcd 100644 --- a/winpr/libwinpr/handle/CMakeLists.txt +++ b/winpr/libwinpr/handle/CMakeLists.txt @@ -20,8 +20,7 @@ set(MODULE_PREFIX "WINPR_HANDLE") set(${MODULE_PREFIX}_SRCS handle.c - handle.h - table.c) + handle.h) if(MSVC AND (NOT MONOLITHIC_BUILD)) set(${MODULE_PREFIX}_SRCS ${${MODULE_PREFIX}_SRCS} module.def) diff --git a/winpr/libwinpr/handle/table.c b/winpr/libwinpr/handle/table.c deleted file mode 100644 index eb81e2b0c..000000000 --- a/winpr/libwinpr/handle/table.c +++ /dev/null @@ -1,32 +0,0 @@ -/** - * WinPR: Windows Portable Runtime - * Handle Management - * - * Copyright 2012 Marc-Andre Moreau - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifdef HAVE_CONFIG_H -#include "config.h" -#endif - -#include -#include - -#ifndef _WIN32 - -#include "../handle/handle.h" - -#endif - diff --git a/winpr/libwinpr/locale/locale.c b/winpr/libwinpr/locale/locale.c index 439401128..177484b1d 100644 --- a/winpr/libwinpr/locale/locale.c +++ b/winpr/libwinpr/locale/locale.c @@ -90,7 +90,17 @@ #ifndef _WIN32 +DWORD WINAPI FormatMessageA(DWORD dwFlags, LPCVOID lpSource, DWORD dwMessageId, DWORD dwLanguageId, + LPSTR lpBuffer, DWORD nSize, va_list* Arguments) +{ + return 0; +} +DWORD WINAPI FormatMessageW(DWORD dwFlags, LPCVOID lpSource, DWORD dwMessageId, DWORD dwLanguageId, + LPWSTR lpBuffer, DWORD nSize, va_list* Arguments) +{ + return 0; +} #endif diff --git a/winpr/libwinpr/registry/registry_reg.c b/winpr/libwinpr/registry/registry_reg.c index e7122c04f..dafee6ca1 100644 --- a/winpr/libwinpr/registry/registry_reg.c +++ b/winpr/libwinpr/registry/registry_reg.c @@ -425,7 +425,7 @@ void reg_print_value(Reg* reg, RegVal* value) if (value->type == REG_DWORD) { - fprintf(stderr, "dword:%08X\n", value->data.dword); + fprintf(stderr, "dword:%08X\n", (int) value->data.dword); } else if (value->type == REG_SZ) { diff --git a/winpr/libwinpr/smartcard/CMakeLists.txt b/winpr/libwinpr/smartcard/CMakeLists.txt index 4ff7bfd0e..831dc271e 100644 --- a/winpr/libwinpr/smartcard/CMakeLists.txt +++ b/winpr/libwinpr/smartcard/CMakeLists.txt @@ -27,9 +27,13 @@ set(${MODULE_PREFIX}_SRCS smartcard.h smartcard_link.c smartcard_pcsc.c - smartcard_pcsc.h - smartcard_winscard.c - smartcard_winscard.h) + smartcard_pcsc.h) + +if(WIN32) + list(APPEND ${MODULE_PREFIX}_SRCS + smartcard_winscard.c + smartcard_winscard.h) +endif() if(MSVC AND (NOT MONOLITHIC_BUILD)) set(${MODULE_PREFIX}_SRCS ${${MODULE_PREFIX}_SRCS} module.def) @@ -45,7 +49,7 @@ set_target_properties(${MODULE_NAME} PROPERTIES VERSION ${WINPR_VERSION_FULL} SO set_complex_link_libraries(VARIABLE ${MODULE_PREFIX}_LIBS MONOLITHIC ${MONOLITHIC_BUILD} INTERNAL MODULE winpr - MODULES winpr-crt winpr-library winpr-utils) + MODULES winpr-crt winpr-library winpr-environment winpr-utils) if(PCSC_WINPR_FOUND) list(APPEND ${MODULE_PREFIX}_LIBS ${PCSC_WINPR_LIBRARY}) diff --git a/winpr/libwinpr/smartcard/smartcard_pcsc.c b/winpr/libwinpr/smartcard/smartcard_pcsc.c index 7411c4f8f..5ba47a958 100644 --- a/winpr/libwinpr/smartcard/smartcard_pcsc.c +++ b/winpr/libwinpr/smartcard/smartcard_pcsc.c @@ -33,9 +33,72 @@ #include #include #include +#include #include "smartcard_pcsc.h" +/** + * PC/SC transactions: + * http://developersblog.wwpass.com/?p=180 + */ + +/** + * Smart Card Logon on Windows Vista: + * http://blogs.msdn.com/b/shivaram/archive/2007/02/26/smart-card-logon-on-windows-vista.aspx + */ + +/** + * The Smart Card Cryptographic Service Provider Cookbook: + * http://msdn.microsoft.com/en-us/library/ms953432.aspx + * + * SCARDCONTEXT + * + * The context is a communication channel with the smart card resource manager and + * all calls to the resource manager must go through this link. + * + * All functions that take a context as a parameter or a card handle as parameter, + * which is indirectly associated with a particular context, may be blocking calls. + * Examples of these are SCardGetStatusChange and SCardBeginTransaction, which takes + * a card handle as a parameter. If such a function blocks then all operations wanting + * to use the context are blocked as well. So, it is recommended that a CSP using + * monitoring establishes at least two contexts with the resource manager; one for + * monitoring (with SCardGetStatusChange) and one for other operations. + * + * If multiple cards are present, it is recommended that a separate context or pair + * of contexts be established for each card to prevent operations on one card from + * blocking operations on another. + * + * Example one + * + * The example below shows what can happen if a CSP using SCardGetStatusChange for + * monitoring does not establish two contexts with the resource manager. + * The context becomes unusable until SCardGetStatusChange unblocks. + * + * In this example, there is one process running called P1. + * P1 calls SCardEstablishContext, which returns the context hCtx. + * P1 calls SCardConnect (with the hCtx context) which returns a handle to the card, hCard. + * P1 calls SCardGetStatusChange (with the hCtx context) which blocks because + * there are no status changes to report. + * Until the thread running SCardGetStatusChange unblocks, another thread in P1 trying to + * perform an operation using the context hCtx (or the card hCard) will also be blocked. + * + * Example two + * + * The example below shows how transaction control ensures that operations meant to be + * performed without interruption can do so safely within a transaction. + * + * In this example, there are two different processes running; P1 and P2. + * P1 calls SCardEstablishContext, which returns the context hCtx1. + * P2 calls SCardEstablishContext, which returns the context hCtx2. + * P1 calls SCardConnect (with the hCtx1 context) which returns a handle to the card, hCard1. + * P2 calls SCardConnect (with the hCtx2 context) which returns a handle to the same card, hCard2. + * P1 calls SCardBeginTransaction (with the hCard 1 context). + * Until P1 calls SCardEndTransaction (with the hCard1 context), + * any operation using hCard2 will be blocked. + * Once an operation using hCard2 is blocked and until it's returning, + * any operation using hCtx2 (and hCard2) will also be blocked. + */ + //#define DISABLE_PCSC_SCARD_AUTOALLOCATE struct _PCSC_SCARDCONTEXT @@ -46,6 +109,13 @@ struct _PCSC_SCARDCONTEXT }; typedef struct _PCSC_SCARDCONTEXT PCSC_SCARDCONTEXT; +struct _PCSC_SCARDHANDLE +{ + SCARDCONTEXT hContext; + CRITICAL_SECTION lock; +}; +typedef struct _PCSC_SCARDHANDLE PCSC_SCARDHANDLE; + struct _PCSC_READER { char* namePCSC; @@ -56,6 +126,9 @@ typedef struct _PCSC_READER PCSC_READER; static HMODULE g_PCSCModule = NULL; static PCSCFunctionTable g_PCSC = { 0 }; +static HANDLE g_StartedEvent = NULL; +static int g_StartedEventRefCount = 0; + static BOOL g_SCardAutoAllocate = FALSE; static BOOL g_PnP_Notification = TRUE; @@ -196,12 +269,12 @@ PCSC_SCARDCONTEXT* PCSC_GetCardContextData(SCARDCONTEXT hContext) PCSC_SCARDCONTEXT* pContext; if (!g_CardContexts) - return 0; + return NULL; pContext = (PCSC_SCARDCONTEXT*) ListDictionary_GetItemValue(g_CardContexts, (void*) hContext); if (!pContext) - return 0; + return NULL; return pContext; } @@ -286,6 +359,145 @@ BOOL PCSC_UnlockCardContext(SCARDCONTEXT hContext) return TRUE; } +PCSC_SCARDHANDLE* PCSC_GetCardHandleData(SCARDHANDLE hCard) +{ + PCSC_SCARDHANDLE* pCard; + + if (!g_CardHandles) + return NULL; + + pCard = (PCSC_SCARDHANDLE*) ListDictionary_GetItemValue(g_CardHandles, (void*) hCard); + + if (!pCard) + return NULL; + + return pCard; +} + +SCARDCONTEXT PCSC_GetCardContextFromHandle(SCARDHANDLE hCard) +{ + PCSC_SCARDHANDLE* pCard; + + pCard = PCSC_GetCardHandleData(hCard); + + if (!pCard) + return 0; + + return pCard->hContext; +} + +PCSC_SCARDHANDLE* PCSC_ConnectCardHandle(SCARDCONTEXT hContext, SCARDHANDLE hCard) +{ + PCSC_SCARDHANDLE* pCard; + + pCard = (PCSC_SCARDHANDLE*) calloc(1, sizeof(PCSC_SCARDHANDLE)); + + if (!pCard) + return NULL; + + pCard->hContext = hContext; + + InitializeCriticalSectionAndSpinCount(&(pCard->lock), 4000); + + if (!g_CardHandles) + g_CardHandles = ListDictionary_New(TRUE); + + ListDictionary_Add(g_CardHandles, (void*) hCard, (void*) pCard); + + return pCard; +} + +void PCSC_DisconnectCardHandle(SCARDHANDLE hCard) +{ + PCSC_SCARDHANDLE* pCard; + + pCard = PCSC_GetCardHandleData(hCard); + + if (!pCard) + return; + + DeleteCriticalSection(&(pCard->lock)); + + free(pCard); + + if (!g_CardHandles) + return; + + ListDictionary_Remove(g_CardHandles, (void*) hCard); +} + +BOOL PCSC_LockCardHandle(SCARDHANDLE hCard) +{ + PCSC_SCARDHANDLE* pCard; + + pCard = PCSC_GetCardHandleData(hCard); + + if (!pCard) + { + fprintf(stderr, "PCSC_LockCardHandle: invalid handle (%p)\n", (void*) hCard); + return FALSE; + } + + EnterCriticalSection(&(pCard->lock)); + + return TRUE; +} + +BOOL PCSC_UnlockCardHandle(SCARDHANDLE hCard) +{ + PCSC_SCARDHANDLE* pCard; + + pCard = PCSC_GetCardHandleData(hCard); + + if (!pCard) + { + fprintf(stderr, "PCSC_UnlockCardHandle: invalid handle (%p)\n", (void*) hCard); + return FALSE; + } + + LeaveCriticalSection(&(pCard->lock)); + + return TRUE; +} + +BOOL PCSC_LockCardTransaction(SCARDHANDLE hCard) +{ + PCSC_SCARDHANDLE* pCard; + + return TRUE; /* disable for now because it deadlocks */ + + pCard = PCSC_GetCardHandleData(hCard); + + if (!pCard) + { + fprintf(stderr, "PCSC_LockCardTransaction: invalid handle (%p)\n", (void*) hCard); + return FALSE; + } + + EnterCriticalSection(&(pCard->lock)); + + return TRUE; +} + +BOOL PCSC_UnlockCardTransaction(SCARDHANDLE hCard) +{ + PCSC_SCARDHANDLE* pCard; + + return TRUE; /* disable for now because it deadlocks */ + + pCard = PCSC_GetCardHandleData(hCard); + + if (!pCard) + { + fprintf(stderr, "PCSC_UnlockCardTransaction: invalid handle (%p)\n", (void*) hCard); + return FALSE; + } + + LeaveCriticalSection(&(pCard->lock)); + + return TRUE; +} + char* PCSC_GetReaderNameFromAlias(char* nameWinSCard) { int index; @@ -601,34 +813,6 @@ char* PCSC_ConvertReaderNamesToPCSC(const char* names, LPDWORD pcchReaders) return namesPCSC; } -void PCSC_AddCardHandle(SCARDCONTEXT hContext, SCARDHANDLE hCard) -{ - if (!g_CardHandles) - g_CardHandles = ListDictionary_New(TRUE); - - ListDictionary_Add(g_CardHandles, (void*) hCard, (void*) hContext); -} - -void* PCSC_RemoveCardHandle(SCARDHANDLE hCard) -{ - if (!g_CardHandles) - return NULL; - - return ListDictionary_Remove(g_CardHandles, (void*) hCard); -} - -SCARDCONTEXT PCSC_GetCardContextFromHandle(SCARDHANDLE hCard) -{ - SCARDCONTEXT hContext; - - if (!g_CardHandles) - return 0; - - hContext = (SCARDCONTEXT) ListDictionary_GetItemValue(g_CardHandles, (void*) hCard); - - return hContext; -} - void PCSC_AddMemoryBlock(SCARDCONTEXT hContext, void* pvMem) { if (!g_MemoryBlocks) @@ -690,6 +874,12 @@ WINSCARDAPI LONG WINAPI PCSC_SCardReleaseContext(SCARDCONTEXT hContext) if (!g_PCSC.pfnSCardReleaseContext) return SCARD_E_NO_SERVICE; + if (!hContext) + { + fprintf(stderr, "SCardReleaseContext: null hContext\n"); + return status; + } + status = (LONG) g_PCSC.pfnSCardReleaseContext(hContext); status = PCSC_MapErrorCodeToWinSCard(status); @@ -808,7 +998,7 @@ WINSCARDAPI LONG WINAPI PCSC_SCardListReaders_Internal(SCARDCONTEXT hContext, } else { - status = (LONG) g_PCSC.pfnSCardListReaders(hContext, NULL, mszReaders, &pcsc_cchReaders); + status = (LONG) g_PCSC.pfnSCardListReaders(hContext, mszGroups, mszReaders, &pcsc_cchReaders); } status = PCSC_MapErrorCodeToWinSCard(status); @@ -1113,12 +1303,42 @@ WINSCARDAPI LONG WINAPI PCSC_SCardFreeMemory(SCARDCONTEXT hContext, LPCVOID pvMe WINSCARDAPI HANDLE WINAPI PCSC_SCardAccessStartedEvent(void) { - return 0; + LONG status = 0; + SCARDCONTEXT hContext = 0; + + status = PCSC_SCardEstablishContext(SCARD_SCOPE_SYSTEM, NULL, NULL, &hContext); + + if (status != SCARD_S_SUCCESS) + return NULL; + + status = PCSC_SCardReleaseContext(hContext); + + if (status != SCARD_S_SUCCESS) + return NULL; + + if (!g_StartedEvent) + { + g_StartedEvent = CreateEvent(NULL, TRUE, FALSE, NULL); + SetEvent(g_StartedEvent); + } + + g_StartedEventRefCount++; + + return g_StartedEvent; } WINSCARDAPI void WINAPI PCSC_SCardReleaseStartedEvent(void) { + g_StartedEventRefCount--; + if (g_StartedEventRefCount == 0) + { + if (g_StartedEvent) + { + CloseHandle(g_StartedEvent); + g_StartedEvent = NULL; + } + } } WINSCARDAPI LONG WINAPI PCSC_SCardLocateCardsA(SCARDCONTEXT hContext, @@ -1161,6 +1381,9 @@ WINSCARDAPI LONG WINAPI PCSC_SCardGetStatusChange_Internal(SCARDCONTEXT hContext if (!g_PCSC.pfnSCardGetStatusChange) return SCARD_E_NO_SERVICE; + if (!cReaders) + return SCARD_S_SUCCESS; + /** * Apple's SmartCard Services (not vanilla pcsc-lite) appears to have trouble with the * "\\\\?PnP?\\Notification" reader name. I am always getting EXC_BAD_ACCESS with it. @@ -1171,7 +1394,7 @@ WINSCARDAPI LONG WINAPI PCSC_SCardGetStatusChange_Internal(SCARDCONTEXT hContext * The "\\\\?PnP?\\Notification" string cannot be found anywhere in the sources, * while this string is present in the vanilla pcsc-lite sources. * - * To work around this apparently lack of "\\\\?PnP?\\Notification" support, + * To work around this apparent lack of "\\\\?PnP?\\Notification" support, * we have to filter rgReaderStates to exclude the special PnP reader name. */ @@ -1301,12 +1524,12 @@ WINSCARDAPI LONG WINAPI PCSC_SCardGetStatusChangeW(SCARDCONTEXT hContext, LPSCARD_READERSTATEA states; LONG status = SCARD_S_SUCCESS; - if (!PCSC_LockCardContext(hContext)) - return SCARD_E_INVALID_HANDLE; - if (!g_PCSC.pfnSCardGetStatusChange) return SCARD_E_NO_SERVICE; + if (!PCSC_LockCardContext(hContext)) + return SCARD_E_INVALID_HANDLE; + states = (LPSCARD_READERSTATEA) calloc(cReaders, sizeof(SCARD_READERSTATEA)); if (!states) @@ -1387,7 +1610,7 @@ WINSCARDAPI LONG WINAPI PCSC_SCardConnect_Internal(SCARDCONTEXT hContext, if (status == SCARD_S_SUCCESS) { - PCSC_AddCardHandle(hContext, *phCard); + PCSC_ConnectCardHandle(hContext, *phCard); *pdwActiveProtocol = PCSC_ConvertProtocolsToWinSCard((DWORD) pcsc_dwActiveProtocol); } @@ -1471,7 +1694,9 @@ WINSCARDAPI LONG WINAPI PCSC_SCardDisconnect(SCARDHANDLE hCard, DWORD dwDisposit status = PCSC_MapErrorCodeToWinSCard(status); if (status == SCARD_S_SUCCESS) - PCSC_RemoveCardHandle(hCard); + { + PCSC_DisconnectCardHandle(hCard); + } return status; } @@ -2422,6 +2647,9 @@ extern int PCSC_InitializeSCardApi_Link(void); int PCSC_InitializeSCardApi(void) { + /* Disable pcsc-lite's (poor) blocking so we can handle it ourselves */ + //SetEnvironmentVariableA("PCSCLITE_NO_BLOCKING", "1"); + #ifndef DISABLE_PCSC_LINK if (PCSC_InitializeSCardApi_Link() >= 0) { diff --git a/winpr/libwinpr/smartcard/smartcard_pcsc.h b/winpr/libwinpr/smartcard/smartcard_pcsc.h index b30d6dd70..a53670da2 100644 --- a/winpr/libwinpr/smartcard/smartcard_pcsc.h +++ b/winpr/libwinpr/smartcard/smartcard_pcsc.h @@ -71,9 +71,9 @@ typedef long PCSC_LONG; #define PCSC_SCARD_AUTOALLOCATE (PCSC_DWORD)(-1) -#define PCSC_SCARD_PCI_T0 (&g_PCSC_rgSCardT0Pci) -#define PCSC_SCARD_PCI_T1 (&g_PCSC_rgSCardT1Pci) -#define PCSC_SCARD_PCI_RAW (&g_PCSC_rgSCardRawPci) +#define PCSC_SCARD_PCI_T0 (&g_PCSC_rgSCardT0Pci) +#define PCSC_SCARD_PCI_T1 (&g_PCSC_rgSCardT1Pci) +#define PCSC_SCARD_PCI_RAW (&g_PCSC_rgSCardRawPci) #define PCSC_SCARD_CTL_CODE(code) (0x42000000 + (code)) #define PCSC_CM_IOCTL_GET_FEATURE_REQUEST SCARD_CTL_CODE(3400) diff --git a/winpr/libwinpr/synch/CMakeLists.txt b/winpr/libwinpr/synch/CMakeLists.txt index b2f40cc05..66ce56a53 100644 --- a/winpr/libwinpr/synch/CMakeLists.txt +++ b/winpr/libwinpr/synch/CMakeLists.txt @@ -35,7 +35,6 @@ set(${MODULE_PREFIX}_SRCS semaphore.c sleep.c srw.c - synch.c synch.h timer.c wait.c) diff --git a/winpr/libwinpr/synch/synch.c b/winpr/libwinpr/synch/synch.c deleted file mode 100644 index d27a73f85..000000000 --- a/winpr/libwinpr/synch/synch.c +++ /dev/null @@ -1,25 +0,0 @@ -/** - * WinPR: Windows Portable Runtime - * Synchronization Functions - * - * Copyright 2012 Marc-Andre Moreau - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifdef HAVE_CONFIG_H -#include "config.h" -#endif - -#include - diff --git a/winpr/libwinpr/synch/test/TestSynchCritical.c b/winpr/libwinpr/synch/test/TestSynchCritical.c index ca6f2017a..294d003ee 100644 --- a/winpr/libwinpr/synch/test/TestSynchCritical.c +++ b/winpr/libwinpr/synch/test/TestSynchCritical.c @@ -152,7 +152,7 @@ static PVOID TestSynchCritical_Main(PVOID arg) InitializeCriticalSection(&critical); - for (i=0; i<1000; i++) + for (i = 0; i < 1000; i++) { if (critical.RecursionCount != i) { @@ -200,9 +200,9 @@ static PVOID TestSynchCritical_Main(PVOID arg) dwThreadCount = sysinfo.dwNumberOfProcessors > 1 ? sysinfo.dwNumberOfProcessors : 2; - hThreads = (HANDLE*)calloc(dwThreadCount, sizeof(HANDLE)); + hThreads = (HANDLE*) calloc(dwThreadCount, sizeof(HANDLE)); - for (j=0; j < TEST_SYNC_CRITICAL_TEST1_RUNS; j++) + for (j = 0; j < TEST_SYNC_CRITICAL_TEST1_RUNS; j++) { dwSpinCount = j * 1000; InitializeCriticalSectionAndSpinCount(&critical, dwSpinCount); @@ -212,14 +212,15 @@ static PVOID TestSynchCritical_Main(PVOID arg) /* the TestSynchCritical_Test1 threads shall run until bTest1Running is FALSE */ bTest1Running = TRUE; - for (i=0; i +#include + +#if defined(__linux__) && !defined(__ANDROID__) +#define _GNU_SOURCE +#include +#include +#include +#endif #include "thread.h" @@ -101,6 +109,7 @@ HANDLE CreateThread(LPSECURITY_ATTRIBUTES lpThreadAttributes, SIZE_T dwStackSize WINPR_THREAD* thread; thread = (WINPR_THREAD*) calloc(1, sizeof(WINPR_THREAD)); + if (!thread) return NULL; @@ -155,9 +164,15 @@ HANDLE _GetCurrentThread(VOID) DWORD GetCurrentThreadId(VOID) { +#if defined(__linux__) && !defined(__ANDROID__) + pid_t tid; + tid = syscall(SYS_gettid); + return (DWORD) tid; +#else pthread_t tid; tid = pthread_self(); return (DWORD) tid; +#endif } DWORD ResumeThread(HANDLE hThread) diff --git a/winpr/libwinpr/timezone/timezone.c b/winpr/libwinpr/timezone/timezone.c index 229dc9f0e..f7af0cebc 100644 --- a/winpr/libwinpr/timezone/timezone.c +++ b/winpr/libwinpr/timezone/timezone.c @@ -41,6 +41,9 @@ #ifndef _WIN32 - +BOOL WINAPI FileTimeToSystemTime(const FILETIME *lpFileTime, LPSYSTEMTIME lpSystemTime) +{ + return FALSE; /* unimplemented */ +} #endif diff --git a/winpr/libwinpr/utils/wlog/TextMessage.c b/winpr/libwinpr/utils/wlog/TextMessage.c index 594a79901..e5f110896 100644 --- a/winpr/libwinpr/utils/wlog/TextMessage.c +++ b/winpr/libwinpr/utils/wlog/TextMessage.c @@ -25,3 +25,7 @@ #include "wlog/TextMessage.h" +void wlog_TextMessage_dummy() +{ + /* avoid no symbol ranlib warning */ +} diff --git a/winpr/libwinpr/wnd/test/TestWndCreateWindowEx.c b/winpr/libwinpr/wnd/test/TestWndCreateWindowEx.c index 375dc3745..c9898ac64 100644 --- a/winpr/libwinpr/wnd/test/TestWndCreateWindowEx.c +++ b/winpr/libwinpr/wnd/test/TestWndCreateWindowEx.c @@ -1,6 +1,7 @@ #include #include +#include #include #include diff --git a/winpr/libwinpr/wnd/test/TestWndWmCopyData.c b/winpr/libwinpr/wnd/test/TestWndWmCopyData.c index 303c32cdb..ab075356f 100644 --- a/winpr/libwinpr/wnd/test/TestWndWmCopyData.c +++ b/winpr/libwinpr/wnd/test/TestWndWmCopyData.c @@ -1,6 +1,7 @@ #include #include +#include #include static LRESULT CALLBACK TestWndProc(HWND hwnd, UINT uMsg, WPARAM wParam, LPARAM lParam) diff --git a/winpr/libwinpr/wtsapi/test/TestWtsApiQuerySessionInformation.c b/winpr/libwinpr/wtsapi/test/TestWtsApiQuerySessionInformation.c index 8b9f63ebe..7212869ea 100644 --- a/winpr/libwinpr/wtsapi/test/TestWtsApiQuerySessionInformation.c +++ b/winpr/libwinpr/wtsapi/test/TestWtsApiQuerySessionInformation.c @@ -31,8 +31,6 @@ int TestWtsApiQuerySessionInformation(int argc, char* argv[]) for (index = 0; index < count; index++) { - pBuffer = NULL; - bytesReturned = 0; char* Username; char* Domain; char* ClientName; @@ -44,6 +42,9 @@ int TestWtsApiQuerySessionInformation(int argc, char* argv[]) PWTS_CLIENT_ADDRESS ClientAddress; WTS_CONNECTSTATE_CLASS ConnectState; + pBuffer = NULL; + bytesReturned = 0; + sessionId = pSessionInfo[index].SessionId; printf("[%d] SessionId: %d State: %d\n", (int) index, @@ -52,7 +53,7 @@ int TestWtsApiQuerySessionInformation(int argc, char* argv[]) /* WTSUserName */ - bSuccess = WTSQuerySessionInformation(hServer, sessionId, WTSUserName, &pBuffer, &bytesReturned); + bSuccess = WTSQuerySessionInformationA(hServer, sessionId, WTSUserName, &pBuffer, &bytesReturned); if (!bSuccess) { @@ -65,7 +66,7 @@ int TestWtsApiQuerySessionInformation(int argc, char* argv[]) /* WTSDomainName */ - bSuccess = WTSQuerySessionInformation(hServer, sessionId, WTSDomainName, &pBuffer, &bytesReturned); + bSuccess = WTSQuerySessionInformationA(hServer, sessionId, WTSDomainName, &pBuffer, &bytesReturned); if (!bSuccess) { @@ -78,7 +79,7 @@ int TestWtsApiQuerySessionInformation(int argc, char* argv[]) /* WTSConnectState */ - bSuccess = WTSQuerySessionInformation(hServer, sessionId, WTSConnectState, &pBuffer, &bytesReturned); + bSuccess = WTSQuerySessionInformationA(hServer, sessionId, WTSConnectState, &pBuffer, &bytesReturned); if (!bSuccess) { @@ -91,7 +92,7 @@ int TestWtsApiQuerySessionInformation(int argc, char* argv[]) /* WTSClientBuildNumber */ - bSuccess = WTSQuerySessionInformation(hServer, sessionId, WTSClientBuildNumber, &pBuffer, &bytesReturned); + bSuccess = WTSQuerySessionInformationA(hServer, sessionId, WTSClientBuildNumber, &pBuffer, &bytesReturned); if (!bSuccess) { @@ -104,7 +105,7 @@ int TestWtsApiQuerySessionInformation(int argc, char* argv[]) /* WTSClientName */ - bSuccess = WTSQuerySessionInformation(hServer, sessionId, WTSClientName, &pBuffer, &bytesReturned); + bSuccess = WTSQuerySessionInformationA(hServer, sessionId, WTSClientName, &pBuffer, &bytesReturned); if (!bSuccess) { @@ -117,7 +118,7 @@ int TestWtsApiQuerySessionInformation(int argc, char* argv[]) /* WTSClientProductId */ - bSuccess = WTSQuerySessionInformation(hServer, sessionId, WTSClientProductId, &pBuffer, &bytesReturned); + bSuccess = WTSQuerySessionInformationA(hServer, sessionId, WTSClientProductId, &pBuffer, &bytesReturned); if (!bSuccess) { @@ -130,7 +131,7 @@ int TestWtsApiQuerySessionInformation(int argc, char* argv[]) /* WTSClientHardwareId */ - bSuccess = WTSQuerySessionInformation(hServer, sessionId, WTSClientHardwareId, &pBuffer, &bytesReturned); + bSuccess = WTSQuerySessionInformationA(hServer, sessionId, WTSClientHardwareId, &pBuffer, &bytesReturned); if (!bSuccess) { @@ -143,7 +144,7 @@ int TestWtsApiQuerySessionInformation(int argc, char* argv[]) /* WTSClientAddress */ - bSuccess = WTSQuerySessionInformation(hServer, sessionId, WTSClientAddress, &pBuffer, &bytesReturned); + bSuccess = WTSQuerySessionInformationA(hServer, sessionId, WTSClientAddress, &pBuffer, &bytesReturned); if (!bSuccess) { @@ -157,7 +158,7 @@ int TestWtsApiQuerySessionInformation(int argc, char* argv[]) /* WTSClientDisplay */ - bSuccess = WTSQuerySessionInformation(hServer, sessionId, WTSClientDisplay, &pBuffer, &bytesReturned); + bSuccess = WTSQuerySessionInformationA(hServer, sessionId, WTSClientDisplay, &pBuffer, &bytesReturned); if (!bSuccess) { @@ -172,7 +173,7 @@ int TestWtsApiQuerySessionInformation(int argc, char* argv[]) /* WTSClientProtocolType */ - bSuccess = WTSQuerySessionInformation(hServer, sessionId, WTSClientProtocolType, &pBuffer, &bytesReturned); + bSuccess = WTSQuerySessionInformationA(hServer, sessionId, WTSClientProtocolType, &pBuffer, &bytesReturned); if (!bSuccess) {