/*-------------------------------------------------------------------------
 * clintlib\FTPSession.cpp
 *
 * The implementation of IFTPSession
 *
 * See clintlib\FTPSession.h for descriptions of all public functions.
 *
 * Owner:
 *
 * Copyright 1986-2000 Microsoft Corporation, All Rights Reserved
 *-----------------------------------------------------------------------*/

#include "pch.h"

/*-------------------------------------------------------------------------
 * CInternetSessionImpl
 *-------------------------------------------------------------------------
 * Shared WinINet download machinery for the FTP and HTTP sessions.
 *
 * A worker thread (DownloadThread) reads file data into m_buffer; the
 * caller's thread drives the transfer by repeatedly calling
 * ContinueDownload(), which flushes m_buffer to disk and fires the sink
 * callbacks.  After each step the worker signals m_eventProgress,
 * m_eventFileCompleted or m_eventKillDownload and then waits until
 * m_eventResumeDownload (or m_eventKillDownload) is signaled again.
 *
 * m_pszFileListData layout: [ (char*)-1, entry1, entry2, ..., NULL ].
 * Slot 0 is a "not started yet" sentinel; StartNextFile() pre-increments
 * m_pszFileList, so the first entry processed is slot 1.
 *-----------------------------------------------------------------------*/
template <class TInterfaceClass>
class CInternetSessionImpl : public TInterfaceClass
{
public:
    // pSink can be NULL (no callbacks are fired then).
    CInternetSessionImpl(IInternetSessionSink * pSink) :
        m_pszFileList(NULL),
        m_pszFileListData(NULL),
        m_hInternetSession(NULL),   // Disconnect() closes this; must not start as garbage
        m_hFTPSession(NULL),
        m_hFileConnection(NULL),
        m_hFile(NULL),
        m_threadDownload(NULL),
        m_buffer(NULL),
        m_cBytesRead(0),
        m_nBufferSize(0),
        m_nLastErrorCode(0),
        m_nLastErrorCodeInThread(0),
        m_bAutoDisconnect(false),
        m_bAbortDownload(false),
        m_cTotalBytesRead(0),
        m_cCurrentFileBytesRead(0),
        m_pUpdateSink(pSink)
    {
        m_szDestFolder[0] = 0;
        m_szLastError[0] = 0;
        m_szErrorInThread[0] = 0;

        m_eventResumeDownload     = CreateEvent(NULL, FALSE, FALSE, NULL); // auto-reset
        m_eventKillDownload       = CreateEvent(NULL, TRUE,  FALSE, NULL); // manual-reset
        m_eventProgress           = CreateEvent(NULL, FALSE, FALSE, NULL); // auto-reset
        m_eventDownloadTerminated = CreateEvent(NULL, TRUE,  FALSE, NULL); // manual-reset
        m_eventFileCompleted      = CreateEvent(NULL, FALSE, FALSE, NULL); // auto-reset

        // The thread blocks until m_eventResumeDownload or m_eventKillDownload is signaled.
        debugf("Creating download thread.\n");
        DWORD dum;
        m_threadDownload = CreateThread(NULL, 0, DownloadThread, (void*)this, 0, &dum);
        if (m_threadDownload == NULL)
            debugf("Failed to create thread.\n");
    }

    virtual ~CInternetSessionImpl()
    {
        KillDownload();   // stops the worker; closes and NULLs m_threadDownload
        Disconnect();

        CloseHandle(m_eventResumeDownload);
        CloseHandle(m_eventKillDownload);
        CloseHandle(m_eventProgress);
        CloseHandle(m_eventDownloadTerminated);
        CloseHandle(m_eventFileCompleted);
        if (m_threadDownload)
            CloseHandle(m_threadDownload);

        FreeFileList();
    }

    // Number of file-list slots consumed per downloaded file; CloseDownloadFile()
    // rewinds by this amount when the sink asks for a retry.
    virtual int GetFileListIncrement() = 0;

    void SetSink(IInternetSessionSink * pUpdateSink)
    {
        m_pUpdateSink = pUpdateSink;
    }

    /*-------------------------------------------------------------------------
     * InitiateDownload()
     *-------------------------------------------------------------------------
     * Copies the file list and destination folder, (re)allocates the transfer
     * buffer and resets the hand-off events.  Does not wake the worker thread;
     * the derived classes do that.
     *
     * Paramters:
     *   pszFileList:         NULL-terminated array of entries (may be NULL or empty)
     *   szDestFolder:        local folder; a trailing '\' is appended if missing
     *   bDisconnectWhenDone: stored in m_bAutoDisconnect
     *   nMaxBufferSize:      size in bytes of m_buffer
     * Returns false (and records an error) on a bad folder or allocation failure.
     */
    bool InitiateDownload(const char * const * pszFileList, const char * szDestFolder,
                          bool bDisconnectWhenDone, int nMaxBufferSize)
    {
        // Validate first: m_szDestFolder holds MAX_PATH chars including a possible extra '\'.
        if (szDestFolder == NULL || strlen(szDestFolder) + 2 > sizeof(m_szDestFolder))
        {
            SetLastError(szDestFolder ? ERROR_FILENAME_EXCED_RANGE : ERROR_INVALID_PARAMETER);
            DoError("Invalid local download folder : ");
            return false;
        }

        //
        // Copy pszFileList into our own memory, behind the "not started" sentinel
        //
        FreeFileList();   // don't leak the list of a previous download

        int cEntries = 0;
        while (pszFileList && pszFileList[cEntries])
            ++cEntries;

        m_pszFileListData = new char*[cEntries + 2];
        m_pszFileListData[0] = (char*)-1; // means not started yet (different than NULL)
        for (int i = 0; i < cEntries; ++i)
            m_pszFileListData[i + 1] = _strdup(pszFileList[i]);
        m_pszFileListData[cEntries + 1] = NULL;
        m_pszFileList = m_pszFileListData;

        strcpy(m_szDestFolder, szDestFolder);   // length checked above
        size_t cchFolder = strlen(m_szDestFolder);
        if (cchFolder == 0 || m_szDestFolder[cchFolder - 1] != '\\')
            strcat(m_szDestFolder, "\\");

        m_bAutoDisconnect = bDisconnectWhenDone;

        // A buffer kept from an earlier download may be smaller than requested;
        // DownloadFileBlock() trusts m_nBufferSize, so never reuse a mismatched one.
        if (m_buffer != NULL && m_nBufferSize != (unsigned)nMaxBufferSize)
        {
            ::VirtualFree((void*)m_buffer, 0, MEM_RELEASE);
            m_buffer = NULL;
        }
        m_nBufferSize = nMaxBufferSize;
        if (m_buffer == NULL)
            m_buffer = (char*)::VirtualAlloc(NULL, m_nBufferSize, MEM_COMMIT, PAGE_READWRITE);
        if (m_buffer == NULL)
        {
            DoError("Failed to allocate the download buffer : ");
            return false;
        }

        m_cTotalBytesRead = 0;

        ResetEvent(m_eventResumeDownload);
        ResetEvent(m_eventKillDownload);
        ResetEvent(m_eventProgress);
        ResetEvent(m_eventDownloadTerminated);
        ResetEvent(m_eventFileCompleted);

        return true;
    }

    // Closes the current file and the WinINet handles, and frees m_buffer.
    bool Disconnect()
    {
        FinishCurrentFile(false);

        // Close the child connection explicitly so the member never dangles.
        if (m_hFTPSession)
        {
            ::InternetCloseHandle(m_hFTPSession);   // best effort
            m_hFTPSession = NULL;
        }

        if (m_hInternetSession)
        {
            if (!InternetCloseHandle(m_hInternetSession))
            {
                DoError("Disconnect Failed");
                return false;
            }
            m_hInternetSession = NULL;
        }

        if (m_buffer)
        {
            ::VirtualFree((void*)m_buffer, 0, MEM_RELEASE);
            m_buffer = NULL;
        }

        return true;
    }

    const char* GetDownloadPath()
    {
        return m_szDestFolder;
    }

    // Returns the first recorded error, or NULL if no error has occured.
    const char* GetLastErrorMessage()
    {
        if (m_szLastError[0] != '\0')
            return m_szLastError;
        return NULL;
    }

    void Abort(bool bAutoDisconnect)
    {
        KillDownload();
        m_bAbortDownload = true;
        if (bAutoDisconnect)
            Disconnect();
    }

    /*-------------------------------------------------------------------------
     * ContinueDownload()
     *-------------------------------------------------------------------------
     * Pumps one hand-off with the worker thread.  Returns true while the
     * transfer is still running, false once it is finished or failed.
     */
    bool ContinueDownload()
    {
        if (m_szLastError[0] != '\0')  // if aborted or previous error
            return false;

        // if error within download thread
        if (WaitForSingleObject(m_eventKillDownload, 0) == WAIT_OBJECT_0)
        {
            if (m_szErrorInThread[0] != 0)
            {
                SetLastError(m_nLastErrorCodeInThread);
                // Pass the text as an argument: it can contain a file name or URL with '%'.
                DoError("%s", m_szErrorInThread);
            }
            else if (m_bAutoDisconnect)
                Disconnect();
            return false; // signal that we are done
        }

        if (m_pszFileList == NULL)   // InitiateDownload() has not succeeded
            return false;

        // Note: *m_pszFileList is 0xFFFFFFFF if download thread hasn't started downloading yet
        if (*m_pszFileList != NULL && !m_bAbortDownload)  // if not done
        {
            if (WaitForSingleObject(m_eventProgress, 0) == WAIT_OBJECT_0)
            {
                // Flush buffer as if needed
                if (m_cBytesRead == m_nBufferSize)
                {
                    if (!FlushDownloadBuffer())
                        return false;
                }

                // If progress updates are wanted
                if (m_pUpdateSink)
                    m_pUpdateSink->OnProgress(m_cTotalBytesRead, *m_pszFileList, m_cCurrentFileBytesRead);

                SetEvent(m_eventResumeDownload);
            }
            else if (WaitForSingleObject(m_eventFileCompleted, 0) == WAIT_OBJECT_0)
            {
                FinishCurrentFile(true);
                SetEvent(m_eventResumeDownload);
            }
            return true; // signal NOT yet done with entire transfer
        }

        // Done: signal complete transfer and close up
        if (m_pUpdateSink)
            m_pUpdateSink->OnTransferFinished();

        if (m_bAutoDisconnect)
            Disconnect();

        return false;
    }

protected:

    enum DOWNLOAD_RESULT
    {
        DOWNLOAD_ERROR,
        DOWNLOAD_PROGRESS,
        FILE_COMPLETED,
    };

    /*-------------------------------------------------------------------------
     * DownloadThread()
     *-------------------------------------------------------------------------
     * Worker loop: performs one step (open next file and/or read one block)
     * each time m_eventResumeDownload is signaled, and exits on
     * m_eventKillDownload, on a failed wait, or on a download error.
     */
    static DWORD WINAPI DownloadThread(LPVOID pThreadParameter)
    {
        CInternetSessionImpl * pSession = (CInternetSessionImpl *) pThreadParameter;

        HANDLE pHandles[] = { pSession->m_eventKillDownload, pSession->m_eventResumeDownload };

        // Only a resume signal continues the loop; "!= WAIT_OBJECT_0" used to spin on WAIT_FAILED.
        while (WaitForMultipleObjects(2, pHandles, FALSE, INFINITE) == WAIT_OBJECT_0 + 1)
        {
            if (pSession->m_hFile == NULL)
            {
                // No local file open: advance to the next entry and open it.
                if (!pSession->StartNextFile())
                {
                    SetEvent(pSession->m_eventKillDownload);
                    break;
                }
            }

            if (pSession->m_hFile != NULL)
            {
                DOWNLOAD_RESULT result = DOWNLOAD_ERROR; // default to error, until we know better
                __try
                {
                    result = pSession->DownloadFileBlock();
                }
                __except(EXCEPTION_EXECUTE_HANDLER)
                {
                    // Record a message so ContinueDownload() reports this as an error, not a normal end.
                    pSession->DoErrorInThread("Exception while downloading file.");
                    result = DOWNLOAD_ERROR;
                }

                if (result == DOWNLOAD_PROGRESS)
                {
                    SetEvent(pSession->m_eventProgress);
                }
                else if (result == FILE_COMPLETED)
                {
                    SetEvent(pSession->m_eventFileCompleted);
                }
                else // DOWNLOAD_ERROR
                {
                    SetEvent(pSession->m_eventKillDownload);
                    break;
                }
            }
        }

        debugf("Download thread exiting...\n");
        SetEvent(pSession->m_eventDownloadTerminated);
        return 0;
    }

    // Advances to the next FTP entry and opens it remotely and locally.
    // Returns true with nothing opened when the list is exhausted.
    virtual bool StartNextFile()
    {
        m_cBytesRead = 0;
        m_cCurrentFileBytesRead = 0;

        // Already on the terminator: never walk past the end of the list.
        if (m_pszFileList == NULL || *m_pszFileList == NULL)
            return true;

        ++m_pszFileList;

        unsigned cTries = 0;

        if (*m_pszFileList)
        {
            // Open file for download
            while (!(m_hFileConnection = FtpOpenFile(m_hFTPSession,
                                                     *m_pszFileList,
                                                     GENERIC_READ,
                                                     FTP_TRANSFER_TYPE_BINARY | INTERNET_FLAG_RELOAD,
                                                     0)))
            {
                // Okay, it failed.  On NT5, this seems to sometimes happen for no apparent
                // reason, so lets try a few more times just to be sure.
                cTries++;
                debugf("Failed to open file via FTP for download, try #%d; error code: %d\n", cTries, GetLastError());
                Sleep(500);
                if (cTries >= 10) // more than a glitch--must be an error, like the file not existing
                {
                    DoErrorInThread("Failed to open file (%s) for download.", *m_pszFileList);
                    return false;
                }
            }

            return OpenDownloadFile();
        }

        return true;
    }

    // Closes the remote connection, flushes buffered data and closes the local file.
    bool FinishCurrentFile(bool bCompleted)  // finish as in close, without any downloading
    {
        if (m_hFileConnection)
        {
            if (!InternetCloseHandle(m_hFileConnection))
            {
                DoError("InternetCloseHandle() Failed for download file");
                return false;
            }
            m_hFileConnection = NULL;
        }

        if (!FlushDownloadBuffer())
            return false;

        if (!CloseDownloadFile(bCompleted))
            return false;

        return true;
    }

    // Reads whatever is available (up to the free space in m_buffer) from m_hFileConnection.
    DOWNLOAD_RESULT DownloadFileBlock()
    {
        // A full buffer must be flushed by ContinueDownload() first; a 0-byte read
        // here would otherwise be mistaken for end of file.
        if (m_cBytesRead >= m_nBufferSize)
            return DOWNLOAD_PROGRESS;

        unsigned long cBytesAvail, cBytesJustRead;

        // Find out if data is available for receiving
        if (!InternetQueryDataAvailable(m_hFileConnection, &cBytesAvail, 0, 0))
        {
            DoErrorInThread("InternetQueryDataAvailable() Failed.");
            return DOWNLOAD_ERROR;
        }

        if (cBytesAvail == 0)
            return FILE_COMPLETED;

        // Get the available data
        unsigned long cBytesAttempted = min(cBytesAvail, m_nBufferSize - m_cBytesRead);

        if (!InternetReadFile((void*)m_hFileConnection,
                              (void*)(m_buffer + m_cBytesRead),
                              cBytesAttempted,
                              &cBytesJustRead))
        {
            DoErrorInThread("InternetReadFile() Failed.");
            return DOWNLOAD_ERROR;
        }

        if (cBytesJustRead == 0)  // testing proves this additional check is needed.
            return FILE_COMPLETED;

        m_cBytesRead            += cBytesJustRead;
        m_cCurrentFileBytesRead += cBytesJustRead;
        m_cTotalBytesRead       += cBytesJustRead;

        return DOWNLOAD_PROGRESS;
    }

    // Creates "<m_szDestFolder><current entry>" for writing.  On failure m_hFile stays NULL.
    bool OpenDownloadFile()
    {
        char szFilename[MAX_PATH+20];

        if (strlen(m_szDestFolder) + strlen(*m_pszFileList) >= sizeof(szFilename))
        {
            SetLastError(ERROR_FILENAME_EXCED_RANGE);
            DoErrorInThread("Local file name for (%s) is too long.", *m_pszFileList);
            return false;
        }
        strcpy(szFilename, m_szDestFolder);
        strcat(szFilename, *m_pszFileList);

        m_hFile = CreateFile(szFilename,
                             GENERIC_WRITE,
                             FILE_SHARE_READ,
                             NULL,
                             CREATE_ALWAYS,
                             FILE_ATTRIBUTE_TEMPORARY, // don't write to disk right away (for better performance).
                             NULL);

        if (m_hFile == INVALID_HANDLE_VALUE)
        {
            DoErrorInThread("Failed create file (%s) on local drive.", szFilename);
            // Keep the "no file open" convention (NULL) used by the rest of the class.
            m_hFile = NULL;
            return false;
        }

        return true;
    }

    // Hands m_buffer to the sink, then writes it to m_hFile and empties it.
    bool FlushDownloadBuffer()
    {
        if (m_cBytesRead == 0)
            return true;

        if (m_pUpdateSink && !m_pUpdateSink->OnDataReceived((void*)m_buffer, m_cBytesRead))
        {
            // After looking at the data, the sink decided to skip this file: drop the
            // connection and local file and discard the data.  (This used to fall through
            // to WriteFile() on the NULL handle and report a bogus write error.)
            if (m_hFileConnection)
            {
                ::InternetCloseHandle(m_hFileConnection);
                m_hFileConnection = NULL;
            }
            if (m_hFile)
            {
                ::CloseHandle(m_hFile);
                m_hFile = NULL;
            }
            m_cBytesRead = 0;
            return true;
        }

        DWORD cBytesWritten = 0;
        if (!WriteFile(m_hFile, (void*)m_buffer, m_cBytesRead, &cBytesWritten, NULL) ||
            cBytesWritten != m_cBytesRead)
        {
            DoError("Failed to write the file (%s) to local drive : ", *m_pszFileList);
            return false;
        }

        m_cBytesRead = 0;
        return true;
    }

    // Closes the local file; when bCompleted, lets the sink veto the file, in which
    // case the list is rewound one file so StartNextFile() downloads it again.
    bool CloseDownloadFile(bool bCompleted)
    {
        if (m_hFile == NULL)
            return true;

        // Clear the member before DoError(), which re-enters this function.
        HANDLE hFile = m_hFile;
        m_hFile = NULL;
        if (!::CloseHandle(hFile))
            DoError("Failed to close file %s", *m_pszFileList);

        // If progress updates are wanted
        if (m_pUpdateSink && bCompleted)
        {
            if (!m_pUpdateSink->OnFileCompleted(*m_pszFileList))
            {
                m_pszFileList -= GetFileListIncrement();
                m_cTotalBytesRead -= m_cCurrentFileBytesRead;
            }
        }

        return true;
    }

    /*-------------------------------------------------------------------------
     * FormatErrorMessage()
     *-------------------------------------------------------------------------
     * Writes "(code) <system text> <last server response>" into szBuffer.
     *
     * Paramters:
     *   szBuffer:    destination
     *   dwErrorCode: error code to describe as text
     *   cchBuffer:   size of szBuffer in chars; 0 means szBuffer points inside
     *                m_szLastError and may use the rest of it
     */
    void FormatErrorMessage(char *szBuffer, DWORD dwErrorCode, size_t cchBuffer = 0)
    {
        if (cchBuffer == 0)
            cchBuffer = sizeof(m_szLastError) - (szBuffer - m_szLastError);
        if (cchBuffer < 2)
        {
            if (cchBuffer == 1)
                szBuffer[0] = '\0';
            return;
        }

        _snprintf(szBuffer, cchBuffer - 1, "(%d) ", (int)dwErrorCode);
        szBuffer[cchBuffer - 1] = '\0';

        size_t cchUsed = strlen(szBuffer);
        if (FormatMessage(FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
                          NULL,
                          dwErrorCode,
                          MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT),
                          szBuffer + cchUsed,
                          (DWORD)(cchBuffer - cchUsed),
                          NULL) == 0)
        {
            szBuffer[cchUsed] = '\0';   // leave just the "(code) " prefix
        }

        // Separator, then the server's last response text in whatever space is left.
        cchUsed = strlen(szBuffer);
        if (cchUsed + 2 >= cchBuffer)
            return;
        szBuffer[cchUsed++] = ' ';
        szBuffer[cchUsed] = '\0';

        DWORD dummy;
        DWORD size = (DWORD)(cchBuffer - cchUsed);
        if (!InternetGetLastResponseInfo(&dummy, szBuffer + cchUsed, &size))
            szBuffer[cchUsed] = '\0';
    }

    /*-------------------------------------------------------------------------
     * DoError()
     *-------------------------------------------------------------------------
     * Records the first error (printf-style text plus the GetLastError()
     * description), closes the local file and fires OnError().  Later calls
     * are ignored until m_szLastError is cleared.
     */
    void DoError(const char * szFormat, ...)
    {
        if (m_szLastError[0] != 0)  // don't erase over first error which can cause other errors
            return;

        m_nLastErrorCode = GetLastError();

        char szMsg[sizeof(m_szLastError) - 50];
        va_list pArg;
        va_start(pArg, szFormat);
        _vsnprintf(szMsg, sizeof(szMsg) - 1, szFormat, pArg);
        va_end(pArg);
        szMsg[sizeof(szMsg) - 1] = '\0';   // _vsnprintf does not terminate on truncation

        strcpy(m_szLastError, szMsg);
        size_t cchMsg = strlen(m_szLastError);
        FormatErrorMessage(m_szLastError + cchMsg, m_nLastErrorCode, sizeof(m_szLastError) - cchMsg);

        CloseDownloadFile(false);

        // by now GetLastError() has been cleared
        SetLastError(m_nLastErrorCode); // preserve this for external users

        if (m_pUpdateSink)
            m_pUpdateSink->OnError(m_szLastError);
    }

    // Worker-thread counterpart of DoError(): only records the text and code;
    // ContinueDownload() reports it on the caller's thread.
    void DoErrorInThread(const char * szFormat, ...)
    {
        if (m_szErrorInThread[0] != 0)  // don't erase over first error which can cause other errors
            return;

        m_nLastErrorCodeInThread = GetLastError();

        char szMsg[sizeof(m_szErrorInThread) - 50];
        va_list pArg;
        va_start(pArg, szFormat);
        _vsnprintf(szMsg, sizeof(szMsg) - 1, szFormat, pArg);
        va_end(pArg);
        szMsg[sizeof(szMsg) - 1] = '\0';   // _vsnprintf does not terminate on truncation

        strcpy(m_szErrorInThread, szMsg);
    }

    // Signals the worker to exit, waits up to 5 seconds, then terminates it.
    void KillDownload()
    {
        if (m_threadDownload)
        {
            SetEvent(m_eventKillDownload); // tell thread to exit

            if (WaitForSingleObject(m_eventDownloadTerminated, 5000) == WAIT_TIMEOUT) // wait up to 5 seconds
            {
                // Ideally, the thread would have its own memory instead of using the
                // InternetSession's memory.  Then we wouldn't need to Terminate it like this.
                TerminateThread(m_threadDownload, 0);
            }
            CloseHandle(m_threadDownload);
            m_threadDownload = NULL;
        }
    }

    // Frees the copied file list (entries from slot 1 up to the NULL terminator).
    void FreeFileList()
    {
        if (m_pszFileListData)
        {
            for (char * * psz = m_pszFileListData + 1; *psz; ++psz) // slot 0 is the (char*)-1 sentinel
                free(*psz);
            delete[] m_pszFileListData;
            m_pszFileListData = NULL;
        }
        m_pszFileList = NULL;
    }

protected:

    char * *                m_pszFileList;          // current position in m_pszFileListData
    char * *                m_pszFileListData;      // owned copy of the caller's list
    char                    m_szDestFolder[MAX_PATH];

    volatile HINTERNET      m_hInternetSession;
    volatile HINTERNET      m_hFTPSession;
    volatile HINTERNET      m_hFileConnection;
    volatile HANDLE         m_hFile;                // local file; NULL when none is open

    // events for controlling download thread
    HANDLE                  m_eventResumeDownload;      // ready to resume downloading because either we are starting a new file or a progress sink has finished being fired
    HANDLE                  m_eventKillDownload;        // download was killed by either error or abort
    HANDLE                  m_eventDownloadTerminated;  // download thread was terminated
    HANDLE                  m_eventProgress;            // download progress was made, so time to fire an update event to the sink
    HANDLE                  m_eventFileCompleted;       // file was completed
    HANDLE                  m_threadDownload;

    volatile char *         m_buffer;
    volatile unsigned       m_cBytesRead;           // bytes read into m_buffer for current file
    volatile unsigned       m_nBufferSize;          // alloc-ed size of m_buffer

    char                    m_szLastError[1024];
    char                    m_szErrorInThread[MAX_PATH+100];
    int                     m_nLastErrorCode;
    int                     m_nLastErrorCodeInThread;

    bool                    m_bAutoDisconnect;      // if true, auto disconnects after D/L
    bool                    m_bAbortDownload;       // if true, this class quits downloading ASAP

    volatile unsigned long  m_cTotalBytesRead;      // bytes read for all files
    volatile unsigned long  m_cCurrentFileBytesRead;// bytes read for current file

    IInternetSessionSink *  m_pUpdateSink;
};


/*-------------------------------------------------------------------------
 * CFTPSessionImpl: downloads a list of file names from one FTP directory.
 *-----------------------------------------------------------------------*/
class CFTPSessionImpl : public CInternetSessionImpl<IFTPSession>
{
public:

    CFTPSessionImpl(IFTPSessionUpdateSink * pSink) :
        CInternetSessionImpl<IFTPSession>(pSink)
    {
    }

    virtual ~CFTPSessionImpl()
    {
    }

    // Opens the WinINet session, logs onto szFTPSite and changes to szDirectory.
    virtual bool ConnectToSite(const char * szFTPSite, const char * szDirectory,
                               const char * szUsername, const char * szPassword)
    {
        m_szLastError[0] = '\0';

        // Release handles from an earlier connection instead of leaking them.
        if (m_hFTPSession)
        {
            ::InternetCloseHandle(m_hFTPSession);
            m_hFTPSession = NULL;
        }
        if (m_hInternetSession)
        {
            ::InternetCloseHandle(m_hInternetSession);
            m_hInternetSession = NULL;
        }

        m_hInternetSession = ::InternetOpen(
                                "Microsoft Internet Explorer",  // agent
                                INTERNET_OPEN_TYPE_PROXY,       // access
                                "ftp-gw",                       // proxy server
                                NULL,                           // defaults
                                0);                             // synchronous
        if (m_hInternetSession == NULL)
        {
            DoError("Failed to initialize FTP stuff.");
            return false;
        }

        // Connect to remote FTP server.
        m_hFTPSession = ::InternetConnect(
                                m_hInternetSession,             // Handle from InternetOpen
                                szFTPSite,                      // Server we want to connect to
                                INTERNET_INVALID_PORT_NUMBER,   // Use appropriate port
                                szUsername,                     // Username, can be NULL
                                szPassword,                     // Password, can be NULL
                                INTERNET_SERVICE_FTP,           // Flag to use FTP services
                                0,                              // Flags (see SDK docs)
                                (DWORD_PTR) this);              // Context (pointer-sized)

        if (m_hFTPSession == NULL)
        {
            DoError("Failed to log onto FTP site (%s) : ", szFTPSite);
            return false;
        }

        if (!FtpSetCurrentDirectory(m_hFTPSession, szDirectory))
        {
            DoError("Failed to enter the proper FTP directory (%s) : ", szDirectory);
            return false;
        }

        return true;
    }

    // One list slot (the remote/local file name) per file.
    virtual int GetFileListIncrement()
    {
        return 1;
    }

    // Prepares the transfer, fires the starting OnProgress() and wakes the worker.
    virtual bool InitiateDownload(const char * const * pszFileList, const char * szDestFolder,
                                  bool bDisconnectWhenDone = true, int nMaxBufferSize = 1024*1024)
    {
        // NOTE(review): bDisconnectWhenDone is ignored -- auto-disconnect is always off
        // here, as before.  Callers may rely on the session staying connected; confirm
        // before honoring the flag.
        (void)bDisconnectWhenDone;

        if (!CInternetSessionImpl<IFTPSession>::InitiateDownload(pszFileList, szDestFolder, false, nMaxBufferSize))
            return false;

        // Fire starting point (slot 1 is the first file)
        if (m_pUpdateSink && m_pszFileList[1] != NULL)
            m_pUpdateSink->OnProgress(0, m_pszFileList[1], 0);

        SetEvent(m_eventResumeDownload); // startup the download thread
        return true;
    }
};


/*-------------------------------------------------------------------------
 * CHTTPSessionImpl: downloads (URL, local file name) pairs.
 *-----------------------------------------------------------------------*/
class CHTTPSessionImpl : public CInternetSessionImpl<IHTTPSession>
{
public:

    CHTTPSessionImpl(IHTTPSessionSink * pSink) :
        CInternetSessionImpl<IHTTPSession>(pSink)
    {
        m_hInternetSession = ::InternetOpen(
                                "Microsoft Internet Explorer",  // agent
                                INTERNET_OPEN_TYPE_PRECONFIG,   // access
                                NULL,                           // proxy server
                                NULL,                           // defaults
                                0);                             // synchronous

        if (m_hInternetSession == NULL)
            DoError("Failed to initialize HTTP stuff.");
    }

    virtual ~CHTTPSessionImpl()
    {
    }

    bool ConstructionSuccess()
    {
        return m_hInternetSession != NULL;
    }

    // Two list slots per file: URL, then local file name.
    virtual int GetFileListIncrement()
    {
        return 2;
    }

    // Advances to the next (URL, local name) pair, opens the URL and creates the local file.
    // Returns true with nothing opened when the list is exhausted.
    virtual bool StartNextFile()
    {
        m_cBytesRead = 0;
        m_cCurrentFileBytesRead = 0;

        // Already on the terminator: never walk past the end of the list.
        if (m_pszFileList == NULL || *m_pszFileList == NULL)
            return true;

        ++m_pszFileList;   // -> URL

        unsigned cTries = 0;

        if (*m_pszFileList)
        {
            // Open file for download
            while (!(m_hFileConnection = InternetOpenUrl(m_hInternetSession,
                                                         *m_pszFileList,
                                                         NULL,
                                                         0,
                                                         INTERNET_FLAG_RELOAD | INTERNET_FLAG_NO_CACHE_WRITE,
                                                         0)))
            {
                cTries++;
                debugf("Failed to open URL(%s) for download, try #%d\n", *m_pszFileList, cTries);
                Sleep(500);
                if (cTries >= 5)
                {
                    DoErrorInThread("Failed to open file for download.");
                    return false;
                }
            }

            ++m_pszFileList;   // -> local file name (the NULL terminator if the list has an odd count)
            if (*m_pszFileList == NULL)
            {
                DoErrorInThread("FileList has bad format");
                return false;
            }

            return OpenDownloadFile();
        }

        return true;
    }

    // Prepares the transfer, fires the starting OnProgress() and wakes the worker.
    virtual bool InitiateDownload(const char * const * pszFileList, const char * szDestFolder,
                                  int nMaxBufferSize = 1024*1024)
    {
        if (!CInternetSessionImpl<IHTTPSession>::InitiateDownload(pszFileList, szDestFolder, false, nMaxBufferSize))
            return false;

        // First pair is [1] = URL, [2] = local name.  Test [1] first: for an empty
        // list [1] is the terminator and [2] lies outside the array.
        if (m_pUpdateSink && m_pszFileList[1] != NULL && m_pszFileList[2] != NULL)
            m_pUpdateSink->OnProgress(0, m_pszFileList[2], 0);

        SetEvent(m_eventResumeDownload); // startup the download thread
        return true;
    }
};


// Returns a ready HTTP session, or NULL if WinINet could not be initialized.
IHTTPSession * CreateHTTPSession(IHTTPSessionSink * pUpdateSink /*= NULL*/)
{
    CHTTPSessionImpl * pNew = new CHTTPSessionImpl(pUpdateSink);

    if (pNew && pNew->ConstructionSuccess())
        return pNew;

    delete pNew;   // previously leaked on failure
    return NULL;
}

// Returns a new FTP session; call ConnectToSite() before downloading.
IFTPSession * CreateFTPSession(IFTPSessionUpdateSink * pUpdateSink /*= NULL*/)
{
    CFTPSessionImpl * pNew = new CFTPSessionImpl(pUpdateSink);
    return pNew;
}