view ui/downloader_win.cpp @ 1243:cf5784d2c3a8

(issue54) Safeguard to prohibit starting the application as root
author Andre Heinecke <andre.heinecke@intevation.de>
date Wed, 24 Sep 2014 19:22:47 +0200
parents 09bb19e5e369
children 82fab0c689bf
line wrap: on
line source
/* Copyright (C) 2014 by Bundesamt für Sicherheit in der Informationstechnik
 * Software engineering by Intevation GmbH
 *
 * This file is Free Software under the GNU GPL (v>=2)
 * and comes with ABSOLUTELY NO WARRANTY!
 * See LICENSE.txt for details.
 */
/**
 * @file downloader_win.cpp
 * @brief Downloader implementation for Windows
 *
 * We use the Windows API here instead of Qt because we want to avoid
 * QtNetwork's SSL stack, which is based on OpenSSL and so
 * might be incompatible with GPL code. Also, using the
 * native API means that the security of the SSL implementation
 * is tied to the security of the system.
 *
 */
#include "downloader.h"
#ifdef Q_OS_WIN
#ifndef MYVERSION
#define MYVERSION "1"
#endif

#include <windows.h>
#include <winhttp.h>

#include <string>

#include <QDebug>
#include <QDateTime>
#include <QSaveFile>
#include <QFileInfo>

/* Debug output prefixed with the name of the current function
 * (__PRETTY_FUNCTION__ is a GCC/Clang extension).
 * Written as "if (!1) {} else ..." rather than "if (1) ..." so that a DEBUG
 * statement used as the body of an unbraced if/else cannot capture the
 * caller's else branch. To silence all debug output change "!1" to "1". */
#define DEBUG if (!1) {} else qDebug() << __PRETTY_FUNCTION__

/* Upper bounds for downloads in bytes: 10 MiB for the software and
 * 1 MiB for the list. */
#define MAX_SW_SIZE 10485760
#define MAX_LIST_SIZE 1048576
/** @brief Download a file from the Internet
 *
 * NOTE(review): this free-function prototype appears to be unused in this
 * file; the implementation that is actually called is the member
 * Downloader::downloadFile further below.
 *
 * @param[in] hSession the session to work in.
 * @param[in] hConnect the connection to use.
 * @param[in] resource the resource to download.
 * @param[in] filename where the file should be saved.
 * @param[in] maxSize maximum amount of bytes to download
 *
 * @returns True if the download was successful.
 */
bool downloadFile(HINTERNET hSession, HINTERNET hConnect,
        LPCWSTR resource, const QString &filename, DWORD maxSize);

/** @brief get the last modified header of a resource.
 *
 * Sends a HEAD request for the resource and, once the server certificate
 * has been verified, returns its Last-Modified header. The implementation
 * is the member Downloader::getLastModifiedHeader further below.
 * NOTE(review): this free-function prototype appears to be unused here.
 *
 * On error call GetLastError to get extended error information.
 *
 * @param[in] hSession the session to work in.
 * @param[in] hConnect the connection to use.
 * @param[in] resource the resource to check the last-modified date on
 *
 * @returns the last modified date or a null datetime in case of errors
 */
QDateTime getLastModifiedHeader(HINTERNET hSession,
    HINTERNET hConnect, LPCWSTR resource);

/** @brief verify that the certificate of the request matches
 *
 * Compares the server certificate of the request with the pinned
 * certificate held by the Downloader (see Downloader::verifyCertificate
 * further below, which is the implementation actually used).
 * NOTE(review): this free-function prototype appears to be unused here.
 *
 * @param[in] hRequest The request from which to get the certificate
 *
 * @returns True if the certificate exactly matches the one in hRequest
 */

bool verifyCertificate(HINTERNET hRequest);


/* Server-side paths of the list and of the software installer (saved as
 * .exe by Downloader::run).
 * NOTE(review): both currently point to the same test location; these look
 * like development placeholders - confirm before release. */
#define LIST_RESOURCE "/incoming/aheinecke/test"
#define SW_RESOURCE "/incoming/aheinecke/test"

/** @brief A wrapper around a HINTERNET handle that handles closing
 *
 * Holds a HINTERNET handle and closes it with WinHttpCloseHandle on
 * destruction if it is not NULL.
 *
 * The wrapper is non-copyable: a copy would close the same handle twice.
 * The copy operations are declared private and intentionally left
 * undefined so this also works without C++11 "= delete".
 */
class SmartHINTERNET {
public:
    SmartHINTERNET() : handle(NULL) {}

    ~SmartHINTERNET() {
        if (handle) {
            WinHttpCloseHandle(handle);
        }
    }

    HINTERNET handle;

private:
    SmartHINTERNET(const SmartHINTERNET &);
    SmartHINTERNET &operator=(const SmartHINTERNET &);
};

/** @brief Qt wrapper around FormatMessage
 *
 * Turns the calling thread's last error code into readable text. The
 * system message table is consulted first; if it has no text for the code,
 * the message table of the already loaded winhttp module is tried so that
 * WinHTTP-specific codes can be resolved as well.
 *
 * @returns The trimmed message text, or "Unknown Error <code>" when no
 *          text could be found.
 */
const QString getLastErrorMsg() {
    const DWORD errorCode = GetLastError();
    const DWORD baseFlags = FORMAT_MESSAGE_ALLOCATE_BUFFER |
                            FORMAT_MESSAGE_IGNORE_INSERTS;
    LPWSTR text = NULL;

    FormatMessageW(baseFlags | FORMAT_MESSAGE_FROM_SYSTEM,
                   NULL, errorCode, 0, (LPWSTR)&text, 0, NULL);

    if (text == NULL) {
        const HMODULE winhttpModule = GetModuleHandleW(L"winhttp");
        if (winhttpModule != NULL) {
            FormatMessageW(baseFlags | FORMAT_MESSAGE_FROM_HMODULE,
                           winhttpModule, HRESULT_CODE(errorCode), 0,
                           (LPWSTR)&text, 0, NULL);
        }
    }

    QString message;
    if (text != NULL) {
        message = QString::fromUtf16((const ushort*)text).trimmed();
    } else {
        message = QString("Unknown Error %1").arg(errorCode);
    }
    // LocalFree ignores a NULL argument, so this is safe on both paths.
    LocalFree(text);
    return message;
}


/** @brief open a session with appropriate proxy settings
 *
 * If the current user has a proxy configured in the Internet Explorer
 * settings, the session is opened with that named proxy (and its bypass
 * list). Otherwise, or if opening with that proxy fails, the session is
 * opened with WinHTTP's default proxy settings.
 *
 * @param[out] pHSession pointer to the HINTERNET that receives the session
 *             handle. It is reset to NULL first, so it is NULL on failure.
 *
 * On error call GetLastError to get extended error information.
 *
 * @returns True on success, false on error.
 */
bool openSession(HINTERNET *pHSession)
{
    WINHTTP_CURRENT_USER_IE_PROXY_CONFIG proxyConfig;
    DWORD lastError = ERROR_SUCCESS;

    DEBUG;
    if (!pHSession) {
        SetLastError(ERROR_INVALID_PARAMETER);
        return false;
    }

    // Never mistake a stale value passed in by the caller for a freshly
    // opened session in the fallback check below.
    *pHSession = NULL;

    memset(&proxyConfig, 0, sizeof (WINHTTP_CURRENT_USER_IE_PROXY_CONFIG));

    if (WinHttpGetIEProxyConfigForCurrentUser(&proxyConfig)) {
        if (proxyConfig.fAutoDetect) {
            // TODO Handle this
            qDebug() << "Autodetect is set";
        }

        if (proxyConfig.lpszProxy || proxyConfig.lpszProxyBypass) {
            DEBUG << "Using proxies.";
        }

        if (proxyConfig.lpszProxy) {
            *pHSession = WinHttpOpen(L"TrustBridge " MYVERSION,
                                     WINHTTP_ACCESS_TYPE_NAMED_PROXY,
                                     proxyConfig.lpszProxy,
                                     proxyConfig.lpszProxyBypass, 0);
        }
    }

    if (!*pHSession) {
        DEBUG << "No IE Proxy falling back to default proxy";
        *pHSession = WinHttpOpen(L"TrustBridge " MYVERSION,
                                 WINHTTP_ACCESS_TYPE_DEFAULT_PROXY,
                                 WINHTTP_NO_PROXY_NAME,
                                 WINHTTP_NO_PROXY_BYPASS, 0);
    }

    // Remember the error of the last WinHttpOpen call so the cleanup below
    // cannot replace it before the caller asks for it.
    lastError = GetLastError();

    // Cleanup: the proxy strings returned by
    // WinHttpGetIEProxyConfigForCurrentUser must be released with GlobalFree.
    if (proxyConfig.lpszAutoConfigUrl) {
        GlobalFree(proxyConfig.lpszAutoConfigUrl);
    }

    if (proxyConfig.lpszProxy) {
        GlobalFree(proxyConfig.lpszProxy);
    }

    if (proxyConfig.lpszProxyBypass) {
        GlobalFree(proxyConfig.lpszProxyBypass);
    }

    if (!*pHSession) {
        SetLastError(lastError);
        return false;
    }
    return true;
}


/** @brief initialize a connection in the session
 *
 * Creates the connection handle for the default HTTPS port. No request is
 * sent here; this only prepares the connection.
 *
 * @param[in] hSession the session to work in.
 * @param[out] pHConnect receives the connection handle (NULL on failure).
 * @param[in] url server to connect to, in wchar representation. It is
 *            handed to WinHttpConnect as the server name.
 *
 * On error call GetLastError to get extended error information.
 *
 * @returns True on success, false on error.
 */
bool initializeConnection(HINTERNET hSession, HINTERNET *pHConnect,
        LPCWSTR url)
{
    DEBUG;
    const bool argumentsValid = hSession && pHConnect;
    if (!argumentsValid) {
        SetLastError(ERROR_INVALID_PARAMETER);
        return false;
    }

    HINTERNET connection = WinHttpConnect(hSession, url,
                                          INTERNET_DEFAULT_HTTPS_PORT, 0);
    *pHConnect = connection;
    return connection != NULL;
}

/** @brief Create a request
 *
 * Opens a request handle for an HTTPS (WINHTTP_FLAG_SECURE) request and
 * relaxes WinHTTP's own certificate checks (unknown CA, invalid date,
 * CN mismatch, wrong usage). This relies on the server certificate being
 * pinned and compared byte for byte by Downloader::verifyCertificate after
 * the response has been received; callers MUST do that check before
 * trusting any response data.
 *
 * @param[in] hSession the session to work in.
 * @param[in] hConnect the connection to use.
 * @param[out] pHRequest receives the request handle (NULL if opening the
 *             request failed).
 * @param[in] requestType the HTTP verb to use, e.g. L"GET" or L"HEAD".
 * @param[in] resource pointer to the resource to request in wchar
 *            representation.
 *
 * On error call GetLastError to get extended error information.
 * This function does not do any networking; it only prepares the request.
 *
 * @returns True on success, false on error.
 */

bool createRequest(HINTERNET hSession, HINTERNET hConnect,
        HINTERNET *pHRequest, LPCWSTR requestType, LPCWSTR resource)
{
    DWORD dwSSLFlag;
    DEBUG;
    if (!hSession || !hConnect || !pHRequest) {
        SetLastError(ERROR_INVALID_PARAMETER);
        return false;
    }

    *pHRequest = WinHttpOpenRequest(hConnect, requestType, resource,
                                    NULL, WINHTTP_NO_REFERER,
                                    WINHTTP_DEFAULT_ACCEPT_TYPES,
                                    WINHTTP_FLAG_SECURE);
    if (!*pHRequest) {
        // Do not pass the NULL handle on: WinHttpSetOption would fail and
        // overwrite the error code set by WinHttpOpenRequest.
        return false;
    }

    // Trust is established by certificate pinning, not by the system store.
    dwSSLFlag = SECURITY_FLAG_IGNORE_UNKNOWN_CA;
    dwSSLFlag |= SECURITY_FLAG_IGNORE_CERT_DATE_INVALID;
    dwSSLFlag |= SECURITY_FLAG_IGNORE_CERT_CN_INVALID;
    dwSSLFlag |= SECURITY_FLAG_IGNORE_CERT_WRONG_USAGE;

    if (!WinHttpSetOption(*pHRequest, WINHTTP_OPTION_SECURITY_FLAGS,
                          &dwSSLFlag, sizeof(dwSSLFlag))) {
        // Not fatal: the request then keeps WinHTTP's stricter default
        // checks, which can only make the TLS handshake fail, not succeed.
        DEBUG << "Failed to set security flags: " << getLastErrorMsg();
    }

    return true;
}

/* Compares the encoded server certificate of hRequest byte for byte with
 * the pinned certificate in mCert.
 *
 * Emits error("Invalid certificate", InvalidCertificate) on a mismatch.
 * Returns true only if both certificates are identical; false if they
 * differ or if the server certificate could not be retrieved. */
bool Downloader::verifyCertificate(HINTERNET hRequest)
{
    // WINHTTP_OPTION_SERVER_CERT_CONTEXT stores a *pointer* to a certificate
    // context into the buffer, so the buffer is a pointer variable and its
    // length is the size of that pointer - not sizeof(CERT_CONTEXT).
    PCCERT_CONTEXT certContext = NULL;
    DWORD certContextLen = sizeof(certContext);
    bool retval = false;

    if (!WinHttpQueryOption(hRequest,
                            WINHTTP_OPTION_SERVER_CERT_CONTEXT,
                            &certContext,
                            &certContextLen) || !certContext) {
        DEBUG << "Unable to get server certificate";
        return false;
    }

    QByteArray serverCert(
            reinterpret_cast<const char *>(certContext->pbCertEncoded),
            static_cast<int>(certContext->cbCertEncoded));

    retval = (serverCert == mCert);

    if (!retval) {
        DEBUG << "Certificate is not the same as the pinned one!"
              << "Base64 cert: " << serverCert.toBase64();
        emit error("Invalid certificate", InvalidCertificate);
    }

    // The context handed out by WinHTTP has to be released by the caller.
    CertFreeCertificateContext(certContext);
    return retval;
}

/* Queries the Last-Modified date of resource with a HEAD request.
 *
 * The pinned server certificate is verified before the header is read.
 * Returns the date as UTC, or a null QDateTime if any step fails (the last
 * error is ERROR_INVALID_PARAMETER for NULL arguments). */
QDateTime Downloader::getLastModifiedHeader(HINTERNET hSession,
        HINTERNET hConnect, LPCWSTR resource)
{
    if (!hSession || !hConnect || !resource) {
        SetLastError(ERROR_INVALID_PARAMETER);
        return QDateTime();
    }

    SmartHINTERNET request;
    if (!createRequest(hSession, hConnect, &request.handle, L"HEAD",
                resource)) {
        return QDateTime();
    }

    // Receiving is only attempted once sending succeeded.
    const bool exchanged =
        WinHttpSendRequest(request.handle,
                           WINHTTP_NO_ADDITIONAL_HEADERS,
                           0, WINHTTP_NO_REQUEST_DATA, 0,
                           0, 0) &&
        WinHttpReceiveResponse(request.handle, NULL);
    if (!exchanged) {
        return QDateTime();
    }

    if (!verifyCertificate(request.handle)) {
        DEBUG << "Certificate verification failed";
        return QDateTime();
    }

    SYSTEMTIME modified;
    memset(&modified, 0, sizeof (SYSTEMTIME));
    DWORD bufferLength = sizeof (SYSTEMTIME);

    const BOOL haveHeader = WinHttpQueryHeaders(request.handle,
            WINHTTP_QUERY_LAST_MODIFIED | WINHTTP_QUERY_FLAG_SYSTEMTIME,
            NULL,
            &modified,
            &bufferLength,
            WINHTTP_NO_HEADER_INDEX);
    if (!haveHeader) {
        return QDateTime();
    }

    const QDate day(modified.wYear, modified.wMonth, modified.wDay);
    const QTime timeOfDay(modified.wHour, modified.wMinute,
                          modified.wSecond, modified.wMilliseconds);
    return QDateTime(day, timeOfDay, Qt::UTC);
}

/* Downloads resource with a GET request into fileName.
 *
 * The pinned server certificate is verified before any data is read. If
 * the server announces a Content-Length equal to the size of an existing,
 * readable fileName, the download is skipped and true is returned.
 * NOTE(review): a different file of the same size is therefore treated as
 * up to date.
 *
 * Data is written through a QSaveFile, so fileName is only replaced after
 * the complete download succeeded; on any failure (transport error, write
 * error, more than maxSize bytes) nothing is committed.
 *
 * Returns true on success. On failure GetLastError may give details;
 * ERROR_INVALID_DATA is set when maxSize is exceeded. */
bool Downloader::downloadFile(HINTERNET hSession, HINTERNET hConnect,
        LPCWSTR resource, const QString &fileName, DWORD maxSize)
{
    SmartHINTERNET sHRequest;
    bool retval = false;
    DWORD bytesAvailable = 0,
          bytesRead = 0,
          totalDownloaded = 0,
          contentLength = 0,
          sizeOfDWORD = sizeof (DWORD);
    char outBuf[8192]; // 8KB is the internal buffer size of winhttp

    QSaveFile outputFile(fileName);

    if (!hSession || !hConnect || !resource) {
        SetLastError(ERROR_INVALID_PARAMETER);
        return retval;
    }

    if (!createRequest(hSession, hConnect, &sHRequest.handle, L"GET",
                resource)) {
        return retval;
    }

    if (!WinHttpSendRequest(sHRequest.handle,
                            WINHTTP_NO_ADDITIONAL_HEADERS,
                            0, WINHTTP_NO_REQUEST_DATA, 0,
                            0, 0)) {
        return retval;
    }

    if (!WinHttpReceiveResponse(sHRequest.handle, NULL)) {
        return retval;
    }

    if (!verifyCertificate(sHRequest.handle)) {
        DEBUG << "Certificate verification failed";
        return retval;
    }

    if (!(WinHttpQueryHeaders(sHRequest.handle,
                              WINHTTP_QUERY_CONTENT_LENGTH |
                              WINHTTP_QUERY_FLAG_NUMBER,
                              NULL,
                              &contentLength,
                              &sizeOfDWORD,
                              WINHTTP_NO_HEADER_INDEX))) {
        // Continue anyway as we later really check how
        // much we download.
        DEBUG << "No content-length";
    }

    if (contentLength > maxSize) {
        DEBUG << "Announced content length exceeds the maximum size";
        SetLastError(ERROR_INVALID_DATA);
        return retval;
    }

    if (contentLength) {
        QFileInfo finf(fileName);
        if (finf.exists() && finf.isReadable() &&
                finf.size() == static_cast<qint64>(contentLength)) {
            // We already have data of the same size
            // No need to waste bandwidth.
            DEBUG << "Skipping download because file exists";
            retval = true;
            return retval;
        }
    }

    // Open / Create the file to write to.
    if (!outputFile.open(QIODevice::WriteOnly)) {
        DEBUG << "Failed to open file";
        return retval;
    }

    DEBUG << "output file size: " << outputFile.size();
    do
    {
        bytesRead = 0;

        if (!WinHttpQueryDataAvailable(sHRequest.handle, &bytesAvailable)) {
            DEBUG << "Querying for available data failed";
            retval = false;
            break;
        }

        if (!bytesAvailable) {
            // Might indicate that we are done.
            break;
        }

        if (bytesAvailable > maxSize) {
            DEBUG << "File to large";
            retval = false;
            SetLastError(ERROR_INVALID_DATA);
            break;
        }

        if (!WinHttpReadData(sHRequest.handle, (LPVOID)outBuf,
                             sizeof(outBuf), &bytesRead)) {
            DEBUG << "Error reading data";
            // Must not keep a "true" from an earlier chunk, otherwise a
            // truncated file would be committed below.
            retval = false;
            break;
        }

        if (!bytesRead) {
            // Should not happen as we queried for available
            // bytes before and the function did not return an
            // error.
            DEBUG << "Unable to read available data";
            retval = false;
            break;
        }

        // Write data to file.
        if (outputFile.write(outBuf, bytesRead) !=
                static_cast<qint64>(bytesRead)) {
            DEBUG << "Error writing to file.";
            retval = false;
            break;
        }

        // Completed a read / write cycle. If no error follows
        // the download was successful.
        retval = true;

        totalDownloaded += bytesRead;

        if (totalDownloaded > maxSize) {
            DEBUG << "Downloaded too much data. Breaking.";
            retval = false;
            SetLastError(ERROR_INVALID_DATA);
            break;
        }
    } while (bytesAvailable > 0);

    if (retval && outputFile.isOpen()) {
        // Actually save the file to disk / move to homedir
        retval = outputFile.commit();
    }

    // Without commit() the QSaveFile discards its temporary file.
    return retval;
}

void Downloader::run() {
    bool results = false;
    SmartHINTERNET sHSession;
    SmartHINTERNET sHConnect;
    wchar_t wUrl[mUrl.size() + 1];
    QDateTime lastModifiedSoftware;
    QDateTime lastModifiedList;

    int rc = 0;

    memset(wUrl, 0, sizeof (wchar_t) * (mUrl.size() + 1));

    rc = mUrl.toWCharArray(wUrl);

    if (rc != mUrl.size()) {
        DEBUG << "Problem converting to wchar array";
        return;
    }

    // Should not be necessary because we initialized the memory
    wUrl[rc] = '\0';

    // Initialize connection
    if (!openSession(&sHConnect.handle)) {
        DEBUG << "Failed to open session: " << getLastErrorMsg();
        return;
    }
    if (!initializeConnection(sHConnect.handle, &sHConnect.handle, wUrl)) {
        DEBUG << "Failed to initialize connection: " << getLastErrorMsg();
        return;
    }


    lastModifiedSoftware = getLastModifiedHeader(sHConnect.handle, sHConnect.handle,
            L""SW_RESOURCE);

    lastModifiedList = getLastModifiedHeader(sHConnect.handle, sHConnect.handle,
            L""LIST_RESOURCE);

    if (!lastModifiedList.isValid() || !lastModifiedSoftware.isValid()) {
        DEBUG << "Could not read headers: " << getLastErrorMsg();
        return;
    }

    if (!mLastModSW.isValid() || lastModifiedSoftware > mLastModSW) {
        QString dataDirectory = getDataDirectory();

        if (dataDirectory.isEmpty()) {
            DEBUG << "Failed to get data directory";
            return;
        }

        QString fileName = dataDirectory.append("/SW-")
            .append(lastModifiedSoftware.toString("yyyymmddHHmmss"))
            .append(".exe");

        DEBUG << "fileName: " << fileName;

        if (!downloadFile(sHConnect.handle, sHConnect.handle, L""SW_RESOURCE,
                   fileName, MAX_SW_SIZE)) {
            DEBUG << "Error downloading File: " << getLastErrorMsg();
            return;
        }

        emit newSoftwareAvailable(fileName, lastModifiedSoftware);
    } else if (!mLastModList.isValid() || lastModifiedList > mLastModList) {
        QString dataDirectory = getDataDirectory();

        if (dataDirectory.isEmpty()) {
            DEBUG << "Failed to get data directory";
            return;
        }

        QString fileName = dataDirectory.append("/list-")
            .append(lastModifiedSoftware.toString("yyyymmddHHmmss"))
            .append(".txt");

        DEBUG << "fileName: " << fileName;

        if (!downloadFile(sHConnect.handle, sHConnect.handle, L""LIST_RESOURCE,
                   fileName, MAX_LIST_SIZE)) {
            DEBUG << "Error downloading File: " << getLastErrorMsg();
            return;
        }

        emit newListAvailable(fileName, lastModifiedList);
    }

    DEBUG << "SW date: " << lastModifiedSoftware;
    DEBUG << "List date: " << lastModifiedList;

    if (!results) {
        // Report any errors.
        DEBUG << "Error" << GetLastError();
        emit error(tr("Unknown Problem when connecting"), ErrUnknown);
    }

    return;
}
#endif

http://wald.intevation.org/projects/trustbridge/