#
# Copyright (C) 2006 Chris Halls <halls@debian.org>
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of version 2.1 of the GNU Lesser General Public
# License as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

"""
A client is created each time a file should be downloaded or streamed.
Clients can be generated by:

- apt clients (HttpRequestClient)
- files that should be uncompressed, such as Packages.bz2 (UncompressClient)

CacheEntry objects will notify subscribed clients of changes in state.
"""

import os, re, urlparse, urllib, gzip, bz2
from StringIO import StringIO

from twisted.web import http
from twisted.internet import reactor

from misc import log

class HttpRequestClient(http.Request):
    """
    Request generated from apt clients via http protocol

    Each new request from connected clients generates a new instance of this
    class, and process() is called.

    process() validates the request and, when it is acceptable, registers
    this object with the cache entry for the requested file by calling
    add_request(self) on it.  The request is unregistered again
    (remove_request) when it finishes or when the connection is lost.
    """
    if_modified_since = None # If-modified-since time requested by client
    backend = None           # Backend for this request
    backendServer = None     # Current server to be tried
    cacheEntry = None        # Cache entry for file requested
    is_real_client = True    # This class represents a real client, not a postprocessor

    def __init__(self, channel, queued):
        log.debug("New Request, queued=%s" % (queued), 'HttpRequestClient')
        self.factory = channel.factory
        http.Request.__init__(self, channel, queued)

    def process(self):
        """
        Each new request begins processing here

        The URI is normalised and validated step by step; the first failed
        check answers the client with an error page (see finishCode) and
        stops.  If every check passes, the request is handed to the cache
        entry of the requested file.
        """
        log.debug("Processing request for: %s" % (self.uri), 'HttpRequestClient')
        self.uri = self.clean_path(self.uri)

        # Reject any character that is not in the allowed set
        match = re.search('[^][a-zA-Z0-9~,.+%:;@#?{}()$/_-]', self.uri)
        if match:
            log.err("Invalid characters found in filename at position %s" % (match.start()))
            self.finishCode(http.FORBIDDEN, "Invalid character in filename at position %s" % (match.start()))
            return

        if_modified_since = self.getHeader('if-modified-since')
        if if_modified_since is not None:
            try:
                self.if_modified_since = http.stringToDatetime(
                        if_modified_since)
            except (ValueError, IndexError):
                # A date we cannot parse is treated as if the header had not
                # been sent, instead of aborting the whole request.
                log.debug("Ignoring unparseable if-modified-since header: %s"
                          % (if_modified_since), 'HttpRequestClient')

        if not self.uri.startswith('/'):
            log.err("Request must include at least one '/'")
            self.finishCode(http.FORBIDDEN, "Request must include at least one '/'")
            return

        # First path component selects the backend
        backendName = self.uri[1:].split('/')[0]
        log.debug("Request: %s %s backend=%s uri=%s"
                    % (self.method, self.uri, backendName, self.uri),'HttpRequestClient')

        if self.method != 'GET':
            #we currently only support GET
            log.err("abort - method not implemented", 'HttpRequestClient')
            self.finishCode(http.NOT_IMPLEMENTED)
            return

        if re.search(r'/\.\./', self.uri):
            log.err("/../ in simplified uri ("+self.uri+")", 'HttpRequestClient')
            self.finishCode(http.FORBIDDEN)
            return

        self.backend = self.factory.getBackend(backendName)
        if self.backend is None:
            log.err("backend %s not found" % (backendName), 'HttpRequestClient')
            self.finishCode(http.NOT_FOUND, "NON-EXISTENT BACKEND")
            return

        log.debug("backend: %s %s" % (self.backend.base, self.backend.uris), 'HttpRequestClient')

        # '/backend/path/to/file' -> ['', 'backend', 'path/to/file']
        elements = self.uri.split('/', 2)
        if len(elements) < 3:
            log.err("abort - too few slashes in URI %s" % (self.uri), 'Request')
            self.finishCode(http.FORBIDDEN, 'too few slashes in URI %s' % (self.uri))
            return

        backend_path = elements[2]
        self.cacheEntry = self.backend.get_cache_entry(backend_path)

        if not self.cacheEntry.filetype:
            log.err("abort - unknown extension for file %s" % (backend_path), 'HttpRequestClient')
            self.finishCode(http.FORBIDDEN, 'File not found - unknown extension')
            return

        self.setHeader('content-type', self.cacheEntry.filetype.contype)

        if os.path.isdir(self.cacheEntry.file_path):
            log.err("abort - Directory listing not allowed", 'HttpRequestClient')
            self.finishCode(http.FORBIDDEN, 'Directory listing not permitted')
            return

        self.cacheEntry.add_request(self)

    def clean_path(self, uri):
        """
        Clean up URL given

        Only the path component of the URL is kept; it is unquoted and then
        normalised with os.path.normpath.
        """
        scheme, netloc, path, params, query, fragment = urlparse.urlparse(uri)
        unquoted_path = urllib.url2pathname(path)
        return os.path.normpath(unquoted_path)

    def not_modified(self):
        """
        File is not modified - send http hit
        """
        self.setHeader("content-length", 0)
        self.finishCode(http.NOT_MODIFIED, 'File is up to date')

    def start_streaming(self, size, mtime):
        """
        Prepare client to stream file
        Return false if streaming is not necessary (i.e. cache hit)

        size and mtime may each be None when unknown; the corresponding
        header is then left out.  When mtime is unknown the file is always
        streamed, because there is nothing to compare the client's
        if-modified-since time against.
        """
        if (self.if_modified_since is None
                or mtime is None
                or self.if_modified_since < mtime):
            log.debug("start_streaming size=%s mtime=%s if_modified_since=%s" % (size, mtime, self.if_modified_since) , 'HttpRequestClient')
            if mtime is not None:
                self.setHeader('last-modified', http.datetimeToString(mtime))
            if size is not None:
                self.setHeader('content-length', size)
            self.setResponseCode(http.OK, 'Streaming file')
            return True

        log.debug("file not modified: mtime=%s if_modified_since=%s" % (mtime, self.if_modified_since) , 'HttpRequestClient')
        self.not_modified()
        return False

    def finishCode(self, responseCode, message=None):
        """
        Finish the request with a status code and no streamed data

        A small HTML page naming the code and message is written as the
        response body.  If no message is given, the body shows the standard
        phrase that http.RESPONSES lists for the code (or nothing if the
        code is not listed) instead of the text 'None'.
        """
        log.debug("finishCode: %s, %s" % (responseCode, message), 'HttpRequestClient')
        self.setResponseCode(responseCode, message)
        self.setHeader("content-type", "text/html")
        if message is None:
            body_message = http.RESPONSES.get(responseCode, '')
        else:
            body_message = message
        self.write("<html><head><title>ERROR %d</title></head><body>ERROR %d - %s</body></html>\n" % (responseCode, responseCode, body_message))
        self.finish()

    def finish(self):
        "Finish request after streaming"
        log.debug("finish. fileno:%s uri:%s" % (self.getFileno(), self.uri) , 'HttpRequestClient')
        try:
            http.Request.finish(self)
        except Exception as e:
            # Best effort: log the problem and where finish() was called
            # from, then still detach from the cache entry below.
            log.debug("Unexpected error finishing http request: %s" % (e), 'HttpRequestClient')
            import traceback
            traceback.print_stack()

        self._detach_cache_entry()

    def connectionLost(self, reason=None):
        """
        The connection with the client was lost, remove this request from its
        Fetcher.
        """
        log.debug("connectionLost" , 'HttpRequestClient')
        self._detach_cache_entry()
        #self.finish()

    def _detach_cache_entry(self):
        """
        Unregister this request from its cache entry, if it still has one

        The removal is scheduled with reactor.callLater(0, ...) rather than
        performed immediately, and cacheEntry is cleared so that the removal
        is only ever scheduled once.
        """
        if self.cacheEntry:
            reactor.callLater(0, self.cacheEntry.remove_request, self)
            self.cacheEntry = None

    def getFileno(self):
        """
        Get identifier which is unique per apt client

        Returns the file descriptor of the channel's transport, or -1 if it
        cannot be obtained.
        """
        try:
            fileno = self.channel.transport.fileno()
        except Exception:
            fileno = -1
            log.msg("could not get transport's file descriptor", 'HttpRequestClient')
        return fileno

class UncompressClient:
    """
    Client that decompresses a compressed file (e.g. Packages.bz2) from one
    cache entry into another cache entry

    The client subscribes itself to the compressed cache entry with
    add_request().  Data passed to write() is run through uncompress() and
    handed to the destination cache entry, whose path is the source path
    with the extension 'ext' removed.

    This base class does not define them itself; subclasses are expected to
    provide:
      ext               -- filename extension of the compressed file
      uncompress(data)  -- return the decompressed form of a chunk of data
                           (only needed when write() is not overridden)
    """

    logname = 'UncompressClient' # Name for log messages
    if_modified_since = None
    can_uncompress_stream = True # We can deal with chunks

    class FilenameError(Exception):
        "Raised when a path does not carry the expected compressed extension"
        def __init__(self, filename, msg):
            Exception.__init__(self, filename, msg)
            self.filename = filename
            self.msg = msg
        def __str__(self):
            return("Error in filename (%s): %s" % (self.filename, self.msg))

    def __init__(self, compressedCacheEntry):
        log.debug("New UncompressClient for %s" % (compressedCacheEntry), self.logname)
        # dest must exist before we subscribe: if add_request() calls back
        # into this object straight away, the callbacks read self.dest, and
        # start_streaming() assigns it - which must not be overwritten here.
        self.dest = None
        self.source = compressedCacheEntry
        self.source.add_request(self)

    def get_dest_filename(self, path):
        """
        Return path with the compressed extension (self.ext) removed

        Raises FilenameError if path is shorter than the extension or does
        not end in it.
        """
        extlen = len(self.ext)
        if len(path) < extlen:
            raise self.FilenameError(path, "Filename is too short")
        if not path.endswith(self.ext):
            raise self.FilenameError(path, "Filename does not end in '%s'" % (self.ext))
        return path[:-extlen]

    def not_modified(self):
        "Nothing to do when the source file is not modified"
        pass

    def getFileno(self):
        "There is no file descriptor behind this client; always -1"
        return -1

    def finishCode(self, responseCode, message=None):
        "Request aborted"
        if self.dest:
            self.dest.download_failure(responseCode, message)
        self.remove_from_cache_entry()

    def finish(self):
        "All data has been passed to write(): end the destination download"
        self.disconnect()

    def disconnect(self):
        "Signal end of data to the destination and unsubscribe from the source"
        if self.dest:
            self.dest.download_data_end()
        self.remove_from_cache_entry()

    def remove_from_cache_entry(self):
        """
        Unsubscribe from the source cache entry, if still subscribed

        The removal is scheduled with reactor.callLater(0, ...) and
        self.source is cleared so that it is only scheduled once.
        """
        if self.source:
            reactor.callLater(0, self.source.remove_request, self)
            self.source = None

    def start_streaming(self, size, mtime):
        """
        Prepare the destination cache entry to receive decompressed data

        Returns True if chunks of the compressed file should be passed to
        write(), False otherwise.  Decompression is skipped altogether if
        the destination file already exists with a newer mtime than the
        compressed file.

        Raises FilenameError if the source path lacks the expected extension.
        """
        backend_path = self.get_dest_filename(self.source.path)
        self.dest = self.source.backend.get_cache_entry(backend_path)
        self.dest.stat_file()
        if self.dest.file_mtime is not None and mtime < self.dest.file_mtime:
            log.debug("Skipping decompression of file (%s mtime=%s), destination file (%s, mtime=%s) is newer" 
                                % (self.source.path, mtime, self.dest.filename, self.dest.file_mtime), self.logname)
            self.disconnect()
            return False

        log.debug("Decompressing %s -> %s" % (self.source.path, self.dest.filename), self.logname)
        self.dest.init_tempfile() # Open file for streaming
        self.dest.download_started(None, size, mtime)
        if self.source.state == self.source.STATE_SENDFILE:
            # NOTE(review): for a streaming subclass, finish() here ends the
            # destination download without any data having been written -
            # confirm this is intended.
            self.finish()
            return False # We don't want file to be streamed directly

        # Send us chunks of data only if we can decompress a stream
        return bool(self.can_uncompress_stream)

    def write(self, data):
        "Decompress a chunk of the source file and pass it to the destination"
        log.debug("Decompressing %s bytes (%s)" % (len(data), self.source.cache_path), self.logname)
        uncompressed = self.uncompress(data)
        self.dest.download_data_received(uncompressed)

        
class UncompressClientFile(UncompressClient):
    """
    Variant of UncompressClient for formats that are decompressed from the
    complete file instead of chunk by chunk

    Data passed to write() is discarded; finish() starts the decompression
    by calling uncompressStart(), which this class does not define and a
    subclass is expected to provide.
    """
    can_uncompress_stream = False # We cannot deal with chunks

    def __init__(self, compressedCacheEntry):
        UncompressClient.__init__(self, compressedCacheEntry)

    def write(self, data):
        "Ignore streamed data - nothing to do until finish()"
        return None

    def finish(self):
        "Begin decompressing the file"
        self.uncompressStart()

        
class Bz2UncompressClient(UncompressClient):
    """
    Uncompress file using bzip2 (e.g. Packages.bz2)

    Chunks are fed through a single bz2.BZ2Decompressor that lives as long
    as this client.
    """
    ext = '.bz2'
    logname = 'Bz2UncompressClient'

    def __init__(self, compressedCacheEntry):
        # Create the decompressor first: the base class constructor
        # subscribes this client to the cache entry.
        self.decompressor = bz2.BZ2Decompressor()
        UncompressClient.__init__(self, compressedCacheEntry)

    def uncompress(self, data):
        "Return the decompressed data produced by feeding in this chunk"
        inflated = self.decompressor.decompress(data)
        return inflated


# Note: the commented-out streaming implementation below does not work,
# because gzip.GzipFile wants to seek in the file
#class GzUncompressClient(UncompressClient):
    #"""
    #Uncompress file using gzip (e.g. Packages.gz)
    #"""
    #logname = 'GzUncompressClient'
    #ext = '.gz'
    
    #def __init__(self, compressedCacheEntry):
        #self.string = StringIO()
        #self.unzipper = gzip.GzipFile(compresslevel=0, fileobj = self.string, mode='r')
        #UncompressClient.__init__(self, compressedCacheEntry)
    #def uncompress(self, data):
        #buflen = len(data)
        #self.string.write(data)
        #self.string.seek(-buflen, 1)
        #buf = self.unzipper.read()
        #return buf

class GzUncompressClient(UncompressClientFile):
    """
    Uncompress file using gzip (e.g. Packages.gz)

    The compressed file is read from the source cache entry's file_path
    once uncompressStart() is called, and decompressed a chunk at a time,
    one chunk per reactor iteration.
    """
    logname = 'GzUncompressClient'
    ext = '.gz'
    chunk_size = 65536 # Maximum uncompressed bytes read per reactor iteration

    def __init__(self, compressedCacheEntry):
        UncompressClientFile.__init__(self, compressedCacheEntry)

    def uncompressStart(self):
        "Open the compressed file and schedule decompression of the first chunk"
        # Stream all data
        log.debug("Uncompressing file %s" % (self.source.file_path), self.logname)
        self.uncompressedLen = 0
        # gzip data is binary, so the file must not be opened in text mode
        self.inputFile = open(self.source.file_path, 'rb')
        self.unzipper = gzip.GzipFile(fileobj=self.inputFile, mode='rb')
        reactor.callLater(0, self.uncompressFileChunk)

    def uncompressFileChunk(self):
        """
        Decompress the next chunk and pass it to the destination

        Reschedules itself until the end of the file is reached, then calls
        uncompressDone().  If reading fails, the files are closed and the
        exception is re-raised.
        """
        try:
            uncompressed = self.unzipper.read(self.chunk_size)
        except Exception:
            self._closeFiles()
            raise

        chunkLen = len(uncompressed)
        if chunkLen == 0:
            # End of file
            self.uncompressDone()
            return

        self.uncompressedLen += chunkLen
        self.dest.download_data_received(uncompressed)
        reactor.callLater(0, self.uncompressFileChunk)

    def uncompressDone(self):
        "All data was decompressed: release the files and finish"
        log.debug("Uncompressed file %s, len=%s" % (self.dest.filename, self.uncompressedLen), self.logname)
        self._closeFiles()
        # Bypass UncompressClientFile.finish(), which would start over
        UncompressClient.finish(self)

    def _closeFiles(self):
        """
        Close the gzip reader and the underlying input file

        GzipFile.close() does not close a file object that was passed in
        as fileobj, so the input file has to be closed separately.
        """
        self.unzipper.close()
        self.inputFile.close()
