view ui/downloader_win.cpp @ 502:e551de11d8b6

Properly handle the case that the file does not exist. TRUNCATE makes create file fail if the file does not exist but we need TRUNCATE in the case that the file already exists
author Andre Heinecke <aheinecke@intevation.de>
date Mon, 28 Apr 2014 09:18:07 +0000
parents 09bb19e5e369
children 82fab0c689bf
line wrap: on
line source
/* Copyright (C) 2014 by Bundesamt für Sicherheit in der Informationstechnik
 * Software engineering by Intevation GmbH
 *
 * This file is Free Software under the GNU GPL (v>=2)
 * and comes with ABSOLUTELY NO WARRANTY!
 * See LICENSE.txt for details.
 */
/**
 * @file downloader_win.cpp
 * @brief Downloader implementation for Windows
 *
 * We use Windows API here instead of Qt because we want to avoid
 * QtNetworks SSL stack which is based on OpenSSL and so
 * we might be incompatible with GPL code. Also using the
 * native API means that the security of the SSL implementation
 * is tied to the security of the system.
 *
 */
#include "downloader.h"
#ifdef Q_OS_WIN
#ifndef MYVERSION
#define MYVERSION "1"
#endif

#include <windows.h>
#include <winhttp.h>

#include <QDebug>
#include <QDateTime>
#include <QSaveFile>
#include <QFileInfo>

// Debug helper: streams to qDebug(), prefixed with the signature of the
// enclosing function. NOTE: this expands to a bare `if`, so avoid using it
// as the unbraced body of an if/else (a following `else` would bind to it).
#define DEBUG if (1) qDebug() << __PRETTY_FUNCTION__

// Download limit in bytes (10 MiB) passed as maxSize for the software file.
#define MAX_SW_SIZE 10485760
// Download limit in bytes (1 MiB) passed as maxSize for the list file.
#define MAX_LIST_SIZE 1048576
/** @brief Download a file from the Internet
 *
 * @param[in] hSession the session to work in.
 * @param[in] hConnect the connection to use.
 * @param[in] resource the resource to download.
 * @param[in] filename where the file should be saved.
 * @param[in] maxSize maximum number of bytes to download
 *
 * @returns True if the download was successful.
 */
bool downloadFile(HINTERNET hSession, HINTERNET hConnect,
        LPCWSTR resource, const QString &filename, DWORD maxSize);

/** @brief get the last modified header of a resource.
 *
 * Sends a HEAD request for the resource and reads the Last-Modified
 * header of the response.
 *
 * On error GetLastError may provide extended error information.
 *
 * @param[in] hSession the session to work in.
 * @param[in] hConnect the connection to use.
 * @param[in] resource the resource to check the last-modified date on
 *
 * @returns the last modified date or a null datetime in case of errors
 */
QDateTime getLastModifiedHeader(HINTERNET hSession,
    HINTERNET hConnect, LPCWSTR resource);

/** @brief verify that the certificate of the request matches
 *
 * Compares the server certificate of the request against the pinned
 * certificate held in the member variable mCert.
 *
 * @param[in] hRequest the request from which to get the certificate
 *
 * @returns True if the server certificate of hRequest exactly matches
 *          the pinned certificate.
 */

bool verifyCertificate(HINTERNET hRequest);


// Server paths of the list and of the software that Downloader::run()
// checks and downloads.
// NOTE(review): both currently point to the same path, which looks like a
// test placeholder -- confirm before release.
#define LIST_RESOURCE "/incoming/aheinecke/test"
#define SW_RESOURCE "/incoming/aheinecke/test"

/** @brief A wrapper around a HINTERNET handle that handles closing
 *
 * Holds a HINTERNET handle and closes it with WinHttpCloseHandle on
 * destruction if it is not NULL.
 *
 * Copying is disabled: a copy would hold the same raw handle and both
 * destructors would then close it.
 */
class SmartHINTERNET {
public:
    /** Start out empty; the handle is filled in through the public member. */
    SmartHINTERNET() : handle(NULL) {}

    /** Close the handle if one is held. */
    ~SmartHINTERNET() {
        if (handle) {
            WinHttpCloseHandle(handle);
        }
    }

    /** The owned handle, NULL while nothing is held. */
    HINTERNET handle;

private:
    // Declared but intentionally never defined so that accidental copies
    // fail to compile / link instead of closing the handle twice.
    SmartHINTERNET(const SmartHINTERNET &);
    SmartHINTERNET &operator=(const SmartHINTERNET &);
};

/** @brief Qt wrapper around FormatMessage
 *
 * Looks up the text for the calling thread's last error code, first in
 * the system message table and, if that yields nothing, in the message
 * table of the loaded winhttp module.
 *
 * @returns The trimmed error message, or a generic text containing the
 *          numeric error code if no message could be found.
 */
const QString getLastErrorMsg() {
    // Read the code before calling anything else that might change it.
    const DWORD errorCode = GetLastError();
    LPWSTR message = NULL;
    const DWORD commonFlags = FORMAT_MESSAGE_ALLOCATE_BUFFER |
                              FORMAT_MESSAGE_IGNORE_INSERTS;

    FormatMessageW(commonFlags | FORMAT_MESSAGE_FROM_SYSTEM,
                   NULL, errorCode, 0, (LPWSTR)&message, 0, NULL);

    if (!message) {
        // Not a system message, try the winhttp message table.
        HMODULE winhttpModule = GetModuleHandleW(L"winhttp");
        if (winhttpModule) {
            FormatMessageW(commonFlags | FORMAT_MESSAGE_FROM_HMODULE,
                           winhttpModule, HRESULT_CODE(errorCode), 0,
                           (LPWSTR)&message, 0, NULL);
        }
    }

    QString text;
    if (message) {
        text = QString::fromUtf16((const ushort*)message).trimmed();
    } else {
        text = QString("Unknown Error %1").arg(errorCode);
    }

    // The buffer was allocated by FormatMessageW on our behalf.
    LocalFree(message);
    return text;
}


/** @brief open a session with appropriate proxy settings
 *
 * Uses the proxy configured in the current user's IE settings if one is
 * set and falls back to the default proxy configuration otherwise.
 *
 * @param[out] pHSession pointer to the handle that receives the session.
 *             It is reset to NULL first and stays NULL on failure.
 *
 * On error call GetLastError to get extended error information.
 *
 * @returns True on success, false on error.
 */
bool openSession(HINTERNET *pHSession)
{
    WINHTTP_CURRENT_USER_IE_PROXY_CONFIG proxyConfig;
    DWORD openError = ERROR_SUCCESS;

    DEBUG;
    if (!pHSession) {
        SetLastError(ERROR_INVALID_PARAMETER);
        return false;
    }

    // Do not rely on the caller having zeroed the handle; the fallback
    // below is chosen by checking it for NULL.
    *pHSession = NULL;

    memset(&proxyConfig, 0, sizeof (WINHTTP_CURRENT_USER_IE_PROXY_CONFIG));

    if (WinHttpGetIEProxyConfigForCurrentUser(&proxyConfig)) {
        if (proxyConfig.fAutoDetect) {
            // TODO Handle this
            qDebug() << "Autodetect is set";
        }

        if (proxyConfig.lpszProxy || proxyConfig.lpszProxyBypass) {
            DEBUG << "Using proxies.";
        }

        if (proxyConfig.lpszProxy) {
            *pHSession = WinHttpOpen(L"TrustBridge " MYVERSION,
                                     WINHTTP_ACCESS_TYPE_NAMED_PROXY,
                                     proxyConfig.lpszProxy,
                                     proxyConfig.lpszProxyBypass, 0);
        }
    }

    if (!*pHSession) {
        DEBUG << "No IE Proxy falling back to default proxy";
        *pHSession = WinHttpOpen(L"TrustBridge " MYVERSION,
                                 WINHTTP_ACCESS_TYPE_DEFAULT_PROXY,
                                 WINHTTP_NO_PROXY_NAME,
                                 WINHTTP_NO_PROXY_BYPASS, 0);
    }

    if (!*pHSession) {
        // Remember why opening failed; the cleanup below may change the
        // thread's last error.
        openError = GetLastError();
    }

    // Cleanup
    if (proxyConfig.lpszAutoConfigUrl) {
        GlobalFree(proxyConfig.lpszAutoConfigUrl);
    }

    if (proxyConfig.lpszProxy) {
        GlobalFree(proxyConfig.lpszProxy);
    }

    if (proxyConfig.lpszProxyBypass) {
        GlobalFree(proxyConfig.lpszProxyBypass);
    }

    if (!*pHSession) {
        SetLastError(openError);
        return false;
    }
    return true;
}


/** @brief initialize a connection in the session
 *
 * Only sets up the connection handle, no request is sent here.
 *
 * @param[in] hSession the session to work in.
 * @param[out] pHConnect receives the connection handle (NULL on failure).
 * @param[in] url the server to connect to in wchar representation. The
 *            connection uses the default HTTPS port.
 *
 * On error call GetLastError to get extended error information.
 *
 * @returns True on success, false on error.
 */
bool initializeConnection(HINTERNET hSession, HINTERNET *pHConnect,
        LPCWSTR url)
{
    DEBUG;
    const bool argumentsValid = hSession && pHConnect;
    if (!argumentsValid) {
        SetLastError(ERROR_INVALID_PARAMETER);
        return false;
    }

    HINTERNET connection = WinHttpConnect(hSession, url,
                                          INTERNET_DEFAULT_HTTPS_PORT, 0);
    *pHConnect = connection;

    return connection;
}

/** @brief Create a request
 *
 * Opens a secure (HTTPS) request handle and relaxes the built-in
 * certificate checks (unknown CA, date, common name, usage) on it.
 * Callers in this file check the server certificate themselves with
 * verifyCertificate after receiving the response.
 *
 * @param[in] hSession the session to work in.
 * @param[in] hConnect the connection to use.
 * @param[out] pHRequest receives the request handle (NULL on failure).
 * @param[in] requestType the HTTP verb of the request, e.g. L"GET".
 * @param[in] resource pointer to the resource to request in wchar
 *            representation.
 *
 * On error call GetLastError to get extended error information.
 * This function still does not do any networking but only initializes
 * it.
 *
 * @returns True on success, false on error.
 */
bool createRequest(HINTERNET hSession, HINTERNET hConnect,
        HINTERNET *pHRequest, LPCWSTR requestType, LPCWSTR resource)
{
    DWORD dwSSLFlag;
    DEBUG;
    if (!hSession || !hConnect || !pHRequest) {
        SetLastError(ERROR_INVALID_PARAMETER);
        return false;
    }

    *pHRequest = WinHttpOpenRequest(hConnect, requestType, resource,
                                    NULL, WINHTTP_NO_REFERER,
                                    WINHTTP_DEFAULT_ACCEPT_TYPES,
                                    WINHTTP_FLAG_SECURE);

    if (!*pHRequest) {
        // Return right away so that the error of WinHttpOpenRequest is
        // not replaced by the error of setting options on a NULL handle.
        return false;
    }

    dwSSLFlag = SECURITY_FLAG_IGNORE_UNKNOWN_CA;
    dwSSLFlag |= SECURITY_FLAG_IGNORE_CERT_DATE_INVALID;
    dwSSLFlag |= SECURITY_FLAG_IGNORE_CERT_CN_INVALID;
    dwSSLFlag |= SECURITY_FLAG_IGNORE_CERT_WRONG_USAGE;

    if (!WinHttpSetOption(*pHRequest, WINHTTP_OPTION_SECURITY_FLAGS,
                          &dwSSLFlag, sizeof(dwSSLFlag))) {
        // Not fatal: without the flags the default certificate checks
        // stay active, which is only stricter. Log it so that a later
        // failure to send the request can be understood.
        DEBUG << "Failed to set security flags: " << getLastErrorMsg();
    }

    return true;
}

/** @brief compare the server certificate of a request with the pinned one
 *
 * Fetches the server certificate context of the request and compares its
 * encoded bytes with mCert. Emits error() with InvalidCertificate if they
 * differ.
 *
 * @param[in] hRequest the request from which to get the certificate
 *
 * @returns True if the encoded server certificate equals mCert.
 */
bool Downloader::verifyCertificate(HINTERNET hRequest)
{
    CERT_CONTEXT *serverContext = NULL;
    DWORD contextSize = sizeof(CERT_CONTEXT);

    if (!WinHttpQueryOption(hRequest,
                            WINHTTP_OPTION_SERVER_CERT_CONTEXT,
                            &serverContext,
                            &contextSize)) {
        DEBUG << "Unable to get server certificate";
        return false;
    }

    // Copy the encoded certificate so the comparison works on plain bytes.
    const QByteArray encodedCert(
            (const char *) serverContext->pbCertEncoded,
            serverContext->cbCertEncoded);
    const bool matchesPinned = (encodedCert == mCert);

    if (!matchesPinned) {
        DEBUG << "Certificate is not the same as the pinned one!"
              << "Base64 cert: " << encodedCert.toBase64();
        emit error("Invalid certificate", InvalidCertificate);
    }

    // The context returned by the query has to be released by us.
    CertFreeCertificateContext(serverContext);
    return matchesPinned;
}

/** @brief get the last modified header of a resource
 *
 * Sends a HEAD request for the resource, verifies the server certificate
 * and reads the Last-Modified header of the response as UTC time.
 *
 * @param[in] hSession the session to work in.
 * @param[in] hConnect the connection to use.
 * @param[in] resource the resource to check the last-modified date on
 *
 * @returns the last modified date or a null datetime in case of errors
 */
QDateTime Downloader::getLastModifiedHeader(HINTERNET hSession,
        HINTERNET hConnect, LPCWSTR resource)
{
    const QDateTime invalidDate;
    SmartHINTERNET request;

    if (!(hSession && hConnect && resource)) {
        SetLastError(ERROR_INVALID_PARAMETER);
        return invalidDate;
    }

    if (!createRequest(hSession, hConnect, &request.handle, L"HEAD",
                resource)) {
        return invalidDate;
    }

    const BOOL sent = WinHttpSendRequest(request.handle,
                                         WINHTTP_NO_ADDITIONAL_HEADERS,
                                         0, WINHTTP_NO_REQUEST_DATA, 0,
                                         0, 0);
    if (!sent) {
        return invalidDate;
    }

    if (!WinHttpReceiveResponse(request.handle, NULL)) {
        return invalidDate;
    }

    // Only trust headers that come from the pinned server.
    if (!verifyCertificate(request.handle)) {
        DEBUG << "Certificate verification failed";
        return invalidDate;
    }

    SYSTEMTIME modified;
    DWORD modifiedSize = sizeof (SYSTEMTIME);
    memset(&modified, 0, sizeof (SYSTEMTIME));

    const BOOL haveHeader = WinHttpQueryHeaders(request.handle,
                                                WINHTTP_QUERY_LAST_MODIFIED |
                                                WINHTTP_QUERY_FLAG_SYSTEMTIME,
                                                NULL,
                                                &modified,
                                                &modifiedSize,
                                                WINHTTP_NO_HEADER_INDEX);
    if (!haveHeader) {
        return invalidDate;
    }

    const QDate datePart(modified.wYear, modified.wMonth, modified.wDay);
    const QTime timePart(modified.wHour, modified.wMinute, modified.wSecond,
                         modified.wMilliseconds);
    return QDateTime(datePart, timePart, Qt::UTC);
}

/** @brief Download a resource into a file
 *
 * Sends a GET request for the resource, verifies the server certificate
 * and streams the response body into a QSaveFile. The data is only
 * committed to fileName if the whole transfer succeeded; any read, write
 * or size error aborts the download without committing.
 *
 * If the response announces a Content-Length and a readable file of
 * exactly that size already exists at fileName, the download is skipped
 * and treated as a success.
 *
 * @param[in] hSession the session to work in.
 * @param[in] hConnect the connection to use.
 * @param[in] resource the resource to download.
 * @param[in] fileName where the file should be saved.
 * @param[in] maxSize maximum number of bytes to download
 *
 * @returns True if the file was downloaded and saved (or skipped because
 *          it already exists), false otherwise. A response without any
 *          body data counts as a failure.
 */
bool Downloader::downloadFile(HINTERNET hSession, HINTERNET hConnect,
        LPCWSTR resource, const QString &fileName, DWORD maxSize)
{
    SmartHINTERNET sHRequest;
    bool failed = false;
    DWORD bytesAvailable = 0,
          bytesRead = 0,
          totalDownloaded = 0,
          contentLength = 0,
          sizeOfDWORD = sizeof (DWORD);

    QSaveFile outputFile(fileName);

    if (!hSession || !hConnect || !resource) {
        SetLastError(ERROR_INVALID_PARAMETER);
        return false;
    }

    if (!createRequest(hSession, hConnect, &sHRequest.handle, L"GET",
                resource)) {
        return false;
    }

    if (!WinHttpSendRequest(sHRequest.handle,
                            WINHTTP_NO_ADDITIONAL_HEADERS,
                            0, WINHTTP_NO_REQUEST_DATA, 0,
                            0, 0)) {
        return false;
    }

    if (!WinHttpReceiveResponse(sHRequest.handle, NULL)) {
        return false;
    }

    if (!verifyCertificate(sHRequest.handle)) {
        DEBUG << "Certificate verification failed";
        return false;
    }

    if (!(WinHttpQueryHeaders(sHRequest.handle,
                              WINHTTP_QUERY_CONTENT_LENGTH |
                              WINHTTP_QUERY_FLAG_NUMBER,
                              NULL,
                              &contentLength,
                              &sizeOfDWORD,
                              WINHTTP_NO_HEADER_INDEX))) {
        // Continue anyway as we later really check how
        // much we download.
        DEBUG << "No content-length";
    }

    if (contentLength > maxSize) {
        DEBUG << "Content-Length exceeds maximum size";
        // Same error code as for oversized data in the read loop below.
        SetLastError(ERROR_INVALID_DATA);
        return false;
    }

    if (contentLength) {
        QFileInfo finf(fileName);
        if (finf.exists() && finf.isReadable() &&
                finf.size() == contentLength) {
            // We already have data of the same size
            // No need to waste bandwidth.
            DEBUG << "Skipping download because file exists";
            return true;
        }
    }

    // Open / Create the file to write to.
    if (!outputFile.open(QIODevice::WriteOnly)) {
        DEBUG << "Failed to open file";
        return false;
    }

    DEBUG << "output file size: " << outputFile.size();
    do
    {
        char outBuf[8192]; // 8KB is the internal buffer size of winhttp
        bytesRead = 0;

        if (!WinHttpQueryDataAvailable(sHRequest.handle, &bytesAvailable)) {
            DEBUG << "Querying for available data failed";
            failed = true;
            break;
        }

        if (!bytesAvailable) {
            // Might indicate that we are done.
            break;
        }

        if (bytesAvailable > maxSize) {
            DEBUG << "File to large";
            SetLastError(ERROR_INVALID_DATA);
            failed = true;
            break;
        }

        if (!WinHttpReadData(sHRequest.handle, (LPVOID)outBuf,
                             sizeof(outBuf), &bytesRead)) {
            // A failed read must invalidate the download even if earlier
            // chunks were fine, otherwise a partial file would be saved.
            DEBUG << "Error reading data";
            failed = true;
            break;
        }

        if (!bytesRead) {
            // Should not happen as we queried for available
            // bytes before and the function did not return an
            // error.
            DEBUG << "Unable to read available data";
            failed = true;
            break;
        }

        // Enforce the limit before the chunk is written.
        totalDownloaded += bytesRead;
        if (totalDownloaded > maxSize) {
            DEBUG << "Downloaded too much data. Breaking.";
            SetLastError(ERROR_INVALID_DATA);
            failed = true;
            break;
        }

        // Write data to file.
        if (outputFile.write(outBuf, bytesRead) != bytesRead) {
            DEBUG << "Error writing to file.";
            failed = true;
            break;
        }
    } while (bytesAvailable > 0);

    if (failed || totalDownloaded == 0) {
        // Nothing is committed; the QSaveFile is left uncommitted so the
        // target file is not replaced.
        return false;
    }

    // Actually save the file to disk / move to homedir
    return outputFile.commit();
}

void Downloader::run() {
    bool results = false;
    SmartHINTERNET sHSession;
    SmartHINTERNET sHConnect;
    wchar_t wUrl[mUrl.size() + 1];
    QDateTime lastModifiedSoftware;
    QDateTime lastModifiedList;

    int rc = 0;

    memset(wUrl, 0, sizeof (wchar_t) * (mUrl.size() + 1));

    rc = mUrl.toWCharArray(wUrl);

    if (rc != mUrl.size()) {
        DEBUG << "Problem converting to wchar array";
        return;
    }

    // Should not be necessary because we initialized the memory
    wUrl[rc] = '\0';

    // Initialize connection
    if (!openSession(&sHConnect.handle)) {
        DEBUG << "Failed to open session: " << getLastErrorMsg();
        return;
    }
    if (!initializeConnection(sHConnect.handle, &sHConnect.handle, wUrl)) {
        DEBUG << "Failed to initialize connection: " << getLastErrorMsg();
        return;
    }


    lastModifiedSoftware = getLastModifiedHeader(sHConnect.handle, sHConnect.handle,
            L""SW_RESOURCE);

    lastModifiedList = getLastModifiedHeader(sHConnect.handle, sHConnect.handle,
            L""LIST_RESOURCE);

    if (!lastModifiedList.isValid() || !lastModifiedSoftware.isValid()) {
        DEBUG << "Could not read headers: " << getLastErrorMsg();
        return;
    }

    if (!mLastModSW.isValid() || lastModifiedSoftware > mLastModSW) {
        QString dataDirectory = getDataDirectory();

        if (dataDirectory.isEmpty()) {
            DEBUG << "Failed to get data directory";
            return;
        }

        QString fileName = dataDirectory.append("/SW-")
            .append(lastModifiedSoftware.toString("yyyymmddHHmmss"))
            .append(".exe");

        DEBUG << "fileName: " << fileName;

        if (!downloadFile(sHConnect.handle, sHConnect.handle, L""SW_RESOURCE,
                   fileName, MAX_SW_SIZE)) {
            DEBUG << "Error downloading File: " << getLastErrorMsg();
            return;
        }

        emit newSoftwareAvailable(fileName, lastModifiedSoftware);
    } else if (!mLastModList.isValid() || lastModifiedList > mLastModList) {
        QString dataDirectory = getDataDirectory();

        if (dataDirectory.isEmpty()) {
            DEBUG << "Failed to get data directory";
            return;
        }

        QString fileName = dataDirectory.append("/list-")
            .append(lastModifiedSoftware.toString("yyyymmddHHmmss"))
            .append(".txt");

        DEBUG << "fileName: " << fileName;

        if (!downloadFile(sHConnect.handle, sHConnect.handle, L""LIST_RESOURCE,
                   fileName, MAX_LIST_SIZE)) {
            DEBUG << "Error downloading File: " << getLastErrorMsg();
            return;
        }

        emit newListAvailable(fileName, lastModifiedList);
    }

    DEBUG << "SW date: " << lastModifiedSoftware;
    DEBUG << "List date: " << lastModifiedList;

    if (!results) {
        // Report any errors.
        DEBUG << "Error" << GetLastError();
        emit error(tr("Unknown Problem when connecting"), ErrUnknown);
    }

    return;
}
#endif

http://wald.intevation.org/projects/trustbridge/