/*
 *      Copyright (C) 2014-2015 Jean-Luc Barriere
 *
 *  This library is free software; you can redistribute it and/or modify
 *  it under the terms of the GNU Lesser General Public License as published
 *  by the Free Software Foundation; either version 3, or (at your option)
 *  any later version.
 *
 *  This library is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 *  GNU Lesser General Public License for more details.
 *
 *  You should have received a copy of the GNU Lesser General Public License
 *  along with this library; see the file COPYING.  If not, write to
 *  the Free Software Foundation, 51 Franklin Street, Fifth Floor, Boston,
 *  MA 02110-1301 USA
 *  http://www.gnu.org/copyleft/gpl.html
 *
 */

#include "wsresponse.h"
#include "securesocket.h"
#include "compressor.h"
#include "debug.h"

#include <cctype>   // for toupper
#include <cstdlib>  // for atol
#include <cstdio>
#include <cstring>
#include <new>      // for std::nothrow

// Parsing limits used by this file (typed constants instead of macros)
constexpr int HTTP_TOKEN_MAXSIZE   = 20;      // header field names longer than this are truncated
constexpr int HTTP_HEADER_MAXSIZE  = 4000;    // ReadHeaderLine stops once this many bytes were accumulated
constexpr int RESPONSE_BUFFER_SIZE = 4000;    // size of the stack buffer used by ReadHeaderLine
constexpr int CHUNK_MAX_SIZE       = 0x1FFFF; // largest chunk-size accepted by ReadChunk

using namespace NSROOT;

void WSResponse::init(const WSRequest &request, int maxRedirs, bool trustedLocation, bool followAny)
{
  p = new _response(request);
  // Follow up to maxRedirs redirections (status 301 or 302), as long as the
  // new location passes the origin/trust policy below.
  for (; maxRedirs > 0; --maxRedirs)
  {
    const int status = p->GetStatusCode();
    if (status != 301 && status != 302)
      break;

    URIParser uri(p->Redirection());
    // when trustedLocation is set, an absolute location must use https
    const bool secure = (uri.Scheme() && strncmp("https", uri.Scheme(), 5) == 0);
    const bool allowed = (!trustedLocation || secure);
    // a relative location (no host) is always followed; an absolute one must
    // target the same server, unless followAny is set. The host comparison is
    // only evaluated when uri.Host() is not null.
    if (uri.Host() && !(allowed && (followAny || request.GetServer() == uri.Host())))
      break;

    DBG(DBG_DEBUG, "%s: (%d) LOCATION = %s\n", __FUNCTION__, status, p->Redirection().c_str());
    WSRequest redir(request, uri);
    delete p;
    p = new _response(redir);
  }
}

WSResponse::~WSResponse()
{
  // delete on a null pointer is a no-op, no guard needed
  delete p;
  p = nullptr;
}

/*
 * Read one line from the socket, one byte at a time, until the end-of-line
 * marker eol has been received ("\n" when eol is null). The marker itself is
 * not stored. On return, line holds the bytes read (embedded NUL bytes
 * included) and *len is line's size in bytes.
 * Returns false when ReceiveData delivers no byte before the marker was found;
 * in that case line only holds the bytes already moved out of the local buffer.
 * Once at least HTTP_HEADER_MAXSIZE bytes have been accumulated, reading stops
 * and true is returned with a truncated line.
 */
bool WSResponse::ReadHeaderLine(NetSocket *socket, const char *eol, std::string& line, size_t *len)
{
  char buf[RESPONSE_BUFFER_SIZE];
  const char *s_eol = (eol != nullptr ? eol : "\n");
  const int l_eol = static_cast<int>(strlen(s_eol));
  int p = 0;      // number of bytes pending in buf
  int p_eol = 0;  // number of marker bytes matched so far; they are the last p_eol bytes of buf
  size_t l = 0;   // number of bytes moved into line

  line.clear();
  do
  {
    const bool got = (socket->ReceiveData(&buf[p], 1) > 0);
    if (!got)
    {
      /* No EOL found until end of data */
      *len = l;
      return false;
    }

    const char c = buf[p++];
    if (c == s_eol[p_eol])
    {
      if (++p_eol >= l_eol)
      {
        /* Full marker received: keep only the data preceding it */
        const int n = p - l_eol;
        line.append(buf, static_cast<size_t>(n));
        l += n;
        break;
      }
    }
    else
    {
      /*
       * Mismatch: the current byte may itself start a new marker (e.g. the
       * sequence CR CR LF with marker CRLF), so restart matching with it.
       * This is exact for one- and two-byte markers such as "\n" and CRLF.
       */
      p_eol = (c == s_eol[0] ? 1 : 0);
    }

    /* Buffer almost full: move the line data out, keeping any partial marker */
    if (p > (RESPONSE_BUFFER_SIZE - 2 - l_eol))
    {
      const int n = p - p_eol;
      line.append(buf, static_cast<size_t>(n));
      l += n;
      memmove(buf, buf + n, static_cast<size_t>(p_eol));
      p = p_eol;
    }
  }
  while (l < HTTP_HEADER_MAXSIZE);

  *len = l;
  return true;
}

WSResponse::_response::_response(const WSRequest &request)
: m_socket(nullptr)
, m_successful(false)
, m_statusCode(0)
, m_serverInfo()
, m_etag()
, m_location()
, m_contentTypeStr()
, m_contentType(WS_CTYPE_None)
, m_contentEncoding(WS_CENCODING_None)
, m_contentChunked(false)
, m_contentLength(0)
, m_consumed(0)
, m_chunkBuffer(nullptr)
, m_chunkPtr(nullptr)
, m_chunkEOR(nullptr)
, m_chunkEnd(nullptr)
, m_decoder(nullptr)
{
  // Secure URIs go through a TLS client socket, the others through plain TCP
  if (request.IsSecureURI())
    m_socket = SSLSessionFactory::Instance().NewClientSocket();
  else
    m_socket = new TcpSocket();

  if (m_socket == nullptr)
  {
    DBG(DBG_ERROR, "%s: create socket failed\n", __FUNCTION__);
    return;
  }
  // on connection failure the response stays unsuccessful with status 0
  if (!m_socket->Connect(request.GetServer().c_str(), request.GetPort(), SOCKET_RCVBUF_MINSIZE))
    return;

  m_socket->SetReadAttempt(6); // 60 sec to hang up
  if (!SendRequest(request) || !GetResponse())
  {
    DBG(DBG_ERROR, "%s: invalid response\n", __FUNCTION__);
    return;
  }

  // Only 2xx is a success; 3xx (redirection) leaves m_successful at false
  if (m_statusCode >= 200 && m_statusCode < 300)
    m_successful = true;
  else if (m_statusCode < 200)
    DBG(DBG_WARN, "%s: status %d\n", __FUNCTION__, m_statusCode);
  else if (m_statusCode >= 500)
    DBG(DBG_ERROR, "%s: server error (%d)\n", __FUNCTION__, m_statusCode);
  else if (m_statusCode >= 400)
    DBG(DBG_ERROR, "%s: bad request (%d)\n", __FUNCTION__, m_statusCode);
}

WSResponse::_response::~_response()
{
  // Release in the same order as before: decoder, chunk buffer, then socket.
  // delete / delete[] on a null pointer is harmless, so no guards are needed.
  delete m_decoder;
  m_decoder = nullptr;
  delete [] m_chunkBuffer;
  m_chunkBuffer = m_chunkPtr = m_chunkEOR = m_chunkEnd = nullptr;
  delete m_socket;
  m_socket = nullptr;
}

bool WSResponse::_response::SendRequest(const WSRequest &request)
{
  // Serialize the request, then write it to the connected socket
  std::string message;
  request.MakeMessage(message);
  DBG(DBG_PROTO, "%s: %s\n", __FUNCTION__, message.c_str());
  const bool sent = m_socket->SendData(message.c_str(), message.size());
  if (!sent)
    DBG(DBG_ERROR, "%s: failed (%d)\n", __FUNCTION__, m_socket->GetErrNo());
  return sent;
}

/*
 * Read and parse the response header: the status line first, then every header
 * field until the empty line. Field names are stored upper-cased (truncated to
 * HTTP_TOKEN_MAXSIZE characters) at the front of m_headers, and well-known
 * fields update the matching members. Returns false if the first line is not an
 * HTTP status line.
 */
bool WSResponse::_response::GetResponse()
{
  size_t len;
  std::string strread;
  char token[HTTP_TOKEN_MAXSIZE + 1];
  int token_len = 0;
  bool first = true;
  bool ret = false;

  token[0] = 0;
  while (WSResponse::ReadHeaderLine(m_socket, WS_CRLF, strread, &len))
  {
    const char *line = strread.c_str();
    const char *val = nullptr;

    DBG(DBG_PROTO, "%s: %s\n", __FUNCTION__, line);
    /*
     * The first line of a Response message is the Status-Line, consisting of
     * the protocol version followed by a numeric status code and its associated
     * textual phrase, with each element separated by SP characters.
     */
    if (first)
    {
      first = false;
      int status;
      // strread.size() is the size actually stored, so memcmp stays in bounds
      if (strread.size() > 5 && 0 == memcmp(line, "HTTP", 4) && 1 == sscanf(line, "%*s %d", &status))
      {
        /* We have received a valid feedback */
        m_statusCode = status;
        ret = true;
        /* The status line is not a header field: a reason phrase containing ':' must not be parsed as one */
        continue;
      }
      /* Not a response header */
      return false;
    }

    if (len == 0)
    {
      /* End of header */
      break;
    }

    /*
     * Header fields can be extended over multiple lines by preceding each
     * extra line with at least one SP or HT.
     */
    if ((line[0] == ' ' || line[0] == '\t') && token_len)
    {
      /* Append value of previous token */
      val = line;
    }
    /*
     * Each header field consists of a name followed by a colon (":") and the
     * field value. Field names are case-insensitive. The field value MAY be
     * preceded by any amount of LWS, though a single SP is preferred.
     */
    else if ((val = strchr(line, ':')))
    {
      if ((token_len = static_cast<int>(val - line)) > HTTP_TOKEN_MAXSIZE)
        token_len = HTTP_TOKEN_MAXSIZE;
      // toupper requires a value representable as unsigned char
      for (int i = 0; i < token_len; ++i)
        token[i] = static_cast<char>(toupper(static_cast<unsigned char>(line[i])));
      token[token_len] = 0;
      /* Skip the colon, then any leading LWS; an empty value leaves val on the terminating NUL */
      ++val;
      while (*val == ' ' || *val == '\t')
        ++val;
      m_headers.push_front(std::make_pair(token, ""));
    }
    else
    {
      /* Unknown syntax! Close previous token */
      token_len = 0;
      token[token_len] = 0;
    }

    if (token_len && val)
    {
      // accumulated value of the current field, continuation lines included
      std::string& value = m_headers.front().second;
      value.append(val);
      switch (ws_header_from_upperstr(token))
      {
        case WS_HEADER_ETag:
          m_etag.assign(value);
          break;
        case WS_HEADER_Server:
          m_serverInfo.assign(value);
          break;
        case WS_HEADER_Location:
          m_location.assign(value);
          break;
        case WS_HEADER_Content_Type:
          m_contentTypeStr.assign(value);
          m_contentType = ws_ctype_from_str(value.c_str());
          break;
        case WS_HEADER_Content_Length:
          m_contentLength = atol(value.c_str());
          break;
        case WS_HEADER_Content_Encoding:
          m_contentEncoding = ws_cencoding_from_str(value.c_str());
          if (m_contentEncoding == WS_CENCODING_UNKNOWN)
            DBG(DBG_ERROR, "%s: unsupported content encoding (%s)\n", __FUNCTION__, value.c_str());
          break;
        case WS_HEADER_Transfer_Encoding:
          if (value.compare(0, 7, "chunked") == 0)
            m_contentChunked = true;
          break;
        default:
          break;
      }
    }
  }

  return ret;
}

/*
 * Copy up to buflen bytes of the chunked body into buf. When the current chunk
 * is exhausted, the next chunk-size line is read from the socket and a buffer
 * of that size is allocated. Returns the number of bytes copied; 0 means the
 * body is not chunked, the last chunk was reached, the chunk-size is invalid or
 * too large, allocation failed, or the socket delivered no data.
 */
size_t WSResponse::_response::ReadChunk(void *buf, size_t buflen)
{
  if (!m_contentChunked)
    return 0;

  // no more pending byte in chunk buffer
  if (m_chunkPtr >= m_chunkEnd)
  {
    // process next chunk
    delete [] m_chunkBuffer;
    m_chunkBuffer = m_chunkPtr = m_chunkEOR = m_chunkEnd = nullptr;
    std::string strread;
    size_t len = 0;
    // skip empty lines, e.g. the CRLF closing the previous chunk data
    while (WSResponse::ReadHeaderLine(m_socket, WS_CRLF, strread, &len) && len == 0);
    DBG(DBG_PROTO, "%s: chunked data (%s)\n", __FUNCTION__, strread.c_str());

    // chunk-size is the leading run of hex digits; anything after it (chunk
    // extensions) is ignored. Parsing stops as soon as the value exceeds the
    // limit, so an overlong size cannot wrap around and pass the check.
    unsigned long chunkSize = 0;
    for (char c : strread)
    {
      int digit;
      if (c >= '0' && c <= '9')
        digit = c - '0';
      else if (c >= 'a' && c <= 'f')
        digit = c - 'a' + 10;
      else if (c >= 'A' && c <= 'F')
        digit = c - 'A' + 10;
      else
        break;
      chunkSize = chunkSize * 16 + static_cast<unsigned long>(digit);
      if (chunkSize > CHUNK_MAX_SIZE)
        break;
    }
    if (chunkSize == 0)
      return 0; // that's the end of chunks (or no valid chunk-size)
    // check chunk-size overflow
    if (chunkSize > CHUNK_MAX_SIZE)
    {
      DBG(DBG_ERROR, "%s: chunk-size overflow (req=%u) (max=%u)\n", __FUNCTION__, (unsigned)chunkSize, (unsigned)CHUNK_MAX_SIZE);
      return 0;
    }
    // nothrow form so the null check below is meaningful
    m_chunkBuffer = new (std::nothrow) char[chunkSize];
    if (m_chunkBuffer == nullptr)
      return 0;
    m_chunkPtr = m_chunkEOR = m_chunkBuffer;
    m_chunkEnd = m_chunkBuffer + chunkSize;
  }

  // fill chunk buffer
  if (m_chunkPtr >= m_chunkEOR)
  {
    // ask for new data to fill in the chunk buffer
    // fill at last read position and until to the end
    m_chunkEOR += m_socket->ReceiveData(m_chunkEOR, m_chunkEnd - m_chunkEOR);
  }
  size_t s = static_cast<size_t>(m_chunkEOR - m_chunkPtr);
  if (s > buflen)
    s = buflen;
  memcpy(buf, m_chunkPtr, s);
  m_chunkPtr += s;
  m_consumed += s;
  return s;
}

int WSResponse::_response::SocketStreamReader(void *hdl, void *buf, int sz)
{
  // Reader callback: hdl is the _response whose socket feeds the decoder
  _response *resp = static_cast<_response*>(hdl);
  if (!resp)
    return 0;

  size_t got = 0;
  if (resp->m_contentLength == 0)
  {
    // let read on unknown length
    got = resp->m_socket->ReceiveData(buf, sz);
  }
  else if (resp->m_consumed < resp->m_contentLength)
  {
    // never read past the announced Content-Length
    const size_t remaining = resp->m_contentLength - resp->m_consumed;
    const size_t wanted = static_cast<size_t>(sz);
    got = resp->m_socket->ReceiveData(buf, remaining < wanted ? remaining : wanted);
  }
  resp->m_consumed += got;
  return static_cast<int>(got);
}

int WSResponse::_response::ChunkStreamReader(void *hdl, void *buf, int sz)
{
  // Reader callback: hdl is the _response whose chunked body feeds the decoder
  _response *resp = static_cast<_response*>(hdl);
  if (!resp)
    return 0;
  return static_cast<int>(resp->ReadChunk(buf, sz));
}

size_t WSResponse::_response::ReadContent(char* buf, size_t buflen)
{
  /* Identity content encoding: hand out the body bytes as received */
  if (m_contentEncoding == WS_CENCODING_None)
  {
    if (m_contentChunked)
      return ReadChunk(buf, buflen);

    size_t got = 0;
    if (m_contentLength == 0)
    {
      // let read on unknown length
      got = m_socket->ReceiveData(buf, buflen);
    }
    else if (m_consumed < m_contentLength)
    {
      const size_t remaining = m_contentLength - m_consumed;
      got = m_socket->ReceiveData(buf, remaining < buflen ? remaining : buflen);
    }
    m_consumed += got;
    return got;
  }

  /* Only gzip and deflate are decoded; any other encoding yields no data */
  if (m_contentEncoding != WS_CENCODING_Gzip && m_contentEncoding != WS_CENCODING_Deflate)
    return 0;

  // the decoder pulls its input through the reader matching the transfer coding
  if (m_decoder == nullptr)
    m_decoder = new Decompressor(m_contentChunked ? &ChunkStreamReader : &SocketStreamReader, this);

  size_t got = 0;
  if (m_decoder->HasOutputData())
    got = m_decoder->ReadOutput(buf, buflen);
  if (got == 0 && !m_decoder->IsCompleted())
  {
    if (m_decoder->HasStreamError())
      DBG(DBG_ERROR, "%s: decoding failed: stream error\n", __FUNCTION__);
    else if (m_decoder->HasBufferError())
      DBG(DBG_ERROR, "%s: decoding failed: buffer error\n", __FUNCTION__);
    else
      DBG(DBG_ERROR, "%s: decoding failed\n", __FUNCTION__);
  }
  return got;
}

bool WSResponse::_response::GetHeaderValue(const std::string& header, std::string& value)
{
  // Exact, case-sensitive match on the stored field name; GetResponse stores
  // names upper-cased and pushes each new field to the front of m_headers.
  for (const auto& field : m_headers)
  {
    if (field.first == header)
    {
      value.assign(field.second);
      return true;
    }
  }
  return false;
}
