#!/usr/bin/env python3

import datetime
import getopt
import io
import mimetypes
import os
import pwd
import select
import socket
import sys
import traceback
import urllib.parse


# HTTP server based on recipes 511453 and 511454 from code.activestate.com
# by Pierre Quentel. Added support for indexes, access tests, proper handling
# of the SystemExit exception, fixes for a couple of errors and
# vulnerabilities, getopt, lockfiles, daemonization etc. by
# Jakub Kruszona-Zawadzki.

# Registry of the currently connected clients: one handler per connection.
# key = client socket, value = instance of (a subclass of) ClientHandler
# Entries are added by loop() when a connection is accepted and removed by
# ClientHandler.close().
CLIENT_HANDLERS = {}

# =======================================================================
# The server class. Creating an instance starts a server on the specified
# host and port
# =======================================================================


class Server:
    """Listening TCP endpoint of the asynchronous server.

    Instantiating the class creates a non-blocking IPv4 stream socket with
    SO_REUSEADDR set, binds it to ``(host, port)`` and starts listening with
    a backlog of 50.  The special host name ``'any'`` is translated to the
    empty string before binding.
    """

    def __init__(self, host='localhost', port=80):
        # 'any' is the command-line spelling of "no specific address"
        bind_host = '' if host == 'any' else host
        self.host = bind_host
        self.port = port

        self.socket = listener = socket.socket(socket.AF_INET,
                                               socket.SOCK_STREAM)
        listener.setblocking(0)
        listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        listener.bind((bind_host, port))
        listener.listen(50)

# =====================================================================
# Generic client handler. An instance of this class is created for each
# client connection accepted by the server
# =====================================================================


class ClientHandler(object):
    """Generic per-connection handler driven by the select() main loop.

    The loop calls handle_read() when the client socket is readable,
    handle_write() when it is writable and ``self.writable`` is set, and
    handle_error() when select() reports an exceptional condition.

    Subclasses implement a protocol by overriding request_complete() and
    make_response().
    """

    # maximum number of bytes read at once from a file object in the response
    blocksize = 2048

    def __init__(self, server, client_socket, client_address):
        self.server = server
        self.client_address = client_address
        self.client_socket = client_socket
        self.client_socket.setblocking(0)
        # NOTE(review): getfqdn() may perform a blocking name lookup while
        # the single-threaded loop waits - confirm this is acceptable.
        self.host = socket.getfqdn(client_address[0])
        self.incoming = b''  # bytes received so far for the current request
        self.outgoing = b''  # chunk currently being sent
        self.writable = False  # True once a response is ready to be sent
        self.close_when_done = True  # close the socket after the response
        self.response = None  # list of bytes / file objects still to send

    def handle_error(self):
        """Called when select() flags the socket as errored: drop the client."""
        self.close()

    def handle_read(self):
        """Read available data, buffer it and try to process the request.

        An empty read means the peer closed the connection; the handler is
        then closed and nothing further is processed.  Any socket error also
        closes the handler.
        """
        try:
            buff = self.client_socket.recv(1024)
            if not buff:  # the connection is closed
                self.close()
                # nothing more to do for a closed connection
                return
            # buffer the data in self.incoming
            self.incoming += buff
            self.process_incoming()
        except socket.error:
            self.close()

    def process_incoming(self):
        """Test if the request is complete; if so, build the response
        and set self.writable to True"""
        if not self.request_complete():
            return
        self.response = self.make_response()
        self.outgoing = b''
        self.writable = True

    def request_complete(self):
        """Return True if the request is complete, False otherwise
        Override this method in subclasses"""
        return True

    def make_response(self):
        """Return the list of bytes or file objects whose content will
        be sent to the client
        Override this method in subclasses"""
        return [b"xxx"]

    def handle_write(self):
        """Send (a part of) the response on the socket.

        self.response is a list of bytes objects or file objects.  At most
        one chunk is sent per call.  A file object is read ``blocksize``
        bytes at a time and closed once it is exhausted.  When the whole
        response has been sent the connection is either closed
        (``close_when_done``) or reset for the next request.
        """
        if self.outgoing == b'' and self.response:
            if isinstance(self.response[0], bytes):
                self.outgoing = self.response.pop(0)
            else:
                self.outgoing = self.response[0].read(self.blocksize)  # pylint: disable=E1101
                if not self.outgoing:
                    # file fully sent: release its descriptor right away
                    self._release(self.response.pop(0))
        if self.outgoing:
            try:
                sent = self.client_socket.send(self.outgoing)
            except socket.error:
                self.close()
                return
            if sent < len(self.outgoing):
                # partial send: keep the remainder for the next call
                self.outgoing = self.outgoing[sent:]
            else:
                self.outgoing = b''
        if self.outgoing == b'' and not self.response:
            if self.close_when_done:
                self.close()  # close socket
            else:
                # reset for next request
                self.writable = False
                self.incoming = b''

    @staticmethod
    def _release(item):
        """Close *item* if it is a file-like object; bytes are left alone."""
        closer = getattr(item, 'close', None)
        if callable(closer):
            try:
                closer()
            except OSError:
                # best effort: a failing close must not break the connection
                pass

    def close(self):
        """Unregister the handler, release pending files, close the socket.

        Safe to call more than once: a handler that is no longer registered
        is simply ignored.
        """
        CLIENT_HANDLERS.pop(self.client_socket, None)
        # close file objects of a response that was not completely sent
        pending, self.response = self.response, None
        for item in pending or ():
            self._release(item)
        self.client_socket.close()

# ============================================================================
# Main loop, calling the select() function on the sockets to see if new
# clients are trying to connect, if some clients have sent data, and if the
# clients whose response is ready can receive it.
# For each event, call the appropriate method of the server or of the instance
# of ClientHandler managing the dialog with the client: handle_error(),
# handle_read() or handle_write()
# ============================================================================


def loop(server, handler, timeout=30):
    """Run the select() based main loop until interrupted from the keyboard.

    server  -- object whose ``socket`` attribute is the listening socket
    handler -- callable ``handler(server, client_socket, client_address)``
               returning the ClientHandler for a newly accepted connection
    timeout -- select() timeout in seconds

    A handler may close itself (and leave CLIENT_HANDLERS) while an earlier
    event of the same iteration is processed, so every socket returned by
    select() is looked up again before it is used.
    """
    try:
        while True:
            clients = list(CLIENT_HANDLERS)
            # sockets to which there is something to send:
            # we must test if we can send data
            pending = [sock for sock, client in CLIENT_HANDLERS.items()
                       if client.writable]
            # the heart of the program ! "readable" will have the sockets
            # that have sent data, and the server socket if a new client has
            # tried to connect
            readable, writable, errored = select.select(
                clients + [server.socket], pending, clients, timeout)
            for sock in errored:
                client = CLIENT_HANDLERS.get(sock)
                if client is not None:
                    client.handle_error()
            for sock in readable:
                if sock is server.socket:
                    # server socket readable means a new connection request
                    try:
                        client_socket, client_address = server.socket.accept()
                        CLIENT_HANDLERS[client_socket] = handler(
                            server, client_socket, client_address)
                    except socket.error:
                        # deliberate best effort: a failed accept only
                        # affects that one connection attempt
                        pass
                    continue
                # the client connected on this socket has sent something,
                # unless its handler was closed earlier in this iteration
                client = CLIENT_HANDLERS.get(sock)
                if client is not None:
                    client.handle_read()
            for sock in writable:
                client = CLIENT_HANDLERS.get(sock)  # skip deleted sockets
                if client is not None:
                    client.handle_write()
    except KeyboardInterrupt:
        pass


# =============================================================
# An implementation of the HTTP protocol, supporting persistent
# connections and CGI
# =============================================================

class HTTP(ClientHandler):
    """HTTP/1.0 and HTTP/1.1 handler serving static files and CGI scripts.

    Files are looked up below ``HTTP.root``.  A request whose URL path ends
    in ``.cgi`` is handled by run_cgi(), which executes the script inside
    the server process.  Directories are redirected to the first readable
    entry of ``index_files``.  Malformed requests are answered with a 400
    response instead of raising out of the main loop.
    """

    # parameters to override if necessary
    # the directory to serve files from
    root = os.getcwd()
    # index files for directories
    index_files = [b'index.cgi', b'index.html']
    # print logging info for each request ?
    logging = True
    # size of blocks to read from files and send; files larger than this
    # are streamed instead of being read into memory
    blocksize = 2 << 16

    def __init__(self, server, client_socket, client_address):
        super(HTTP, self).__init__(server, client_socket, client_address)
        self.requestline = None
        self.protocol = None
        self.url = None
        self.file_name = None
        self.rest = None
        self.headers = None
        self.mngt_method = None
        self.path = None
        self.method = None

    def _bad_request(self):
        """Flag the request as malformed and report it as complete.

        make_response() answers a request whose method is None with a 400
        error; the protocol is forced to HTTP/1.1 for the status line.
        """
        self.method = None
        self.protocol = b"HTTP/1.1"
        return True

    def request_complete(self):
        """Parse the buffered request.

        Return False while more data is needed: the "end of headers"
        sequence (b'\\r\\n\\r\\n') has not been received yet, or a POST body
        is shorter than its Content-Length.  Return True once a response can
        be built; this includes malformed requests, which are flagged by
        ``self.method = None``.

        For POST requests the body is stored in a BytesIO assigned to
        sys.stdin so that CGI scripts can read it.
        """
        terminator = self.incoming.find(b'\r\n\r\n')
        if terminator == -1:
            return False
        lines = self.incoming[:terminator].split(b'\r\n')
        self.requestline = lines[0]
        try:
            self.method, self.url, self.protocol = lines[0].strip().split()
        except ValueError:
            # not exactly three whitespace separated fields
            return self._bad_request()
        if self.protocol not in (b"HTTP/1.0", b"HTTP/1.1"):
            return self._bad_request()
        # put request headers in a dictionary
        self.headers = {}
        for line in lines[1:]:
            name, colon, value = line.partition(b':')
            if not colon:
                # a header line without ':' is a malformed request
                return self._bad_request()
            self.headers[name.lower().strip()] = value.strip()
        # persistent connection: decided anew for every request, so that a
        # previous keep-alive request does not pin the connection open
        connection = self.headers.get(b"connection", b"")
        self.close_when_done = not (self.protocol == b"HTTP/1.1" and
                                    connection.lower() == b"keep-alive")
        # parse the url
        try:
            _, _, path, params, query, fragment = urllib.parse.urlparse(
                self.url)
        except ValueError:
            # also covers UnicodeDecodeError for non-ASCII bytes in the URL
            return self._bad_request()
        self.path, self.rest = path, (params, query, fragment)

        if self.method == b'POST':
            # for POST requests, read the request body
            # its length must be specified in the content-length header
            try:
                content_length = int(self.headers.get(b'content-length', 0))
            except ValueError:
                return self._bad_request()
            if content_length < 0:
                return self._bad_request()
            body = self.incoming[terminator + 4:]
            # request is incomplete if not all message body received
            if len(body) < content_length:
                return False
            sys.stdin = io.BytesIO(body)  # compatibility with CGI

        return True

    def make_response(self):
        """Build the response: a list of bytes objects or open files.

        Answers 400 for malformed requests, 501 for methods other than GET,
        POST and HEAD, 403 for paths outside the document root or
        unreadable / non-regular files, 404 for missing files and 500 when
        anything unexpected is raised while building the response.
        """
        try:
            if self.method is None:  # bad request
                return self.err_resp(400, b'Bad request : %s' % self.requestline)
            if self.method not in (b'GET', b'POST', b'HEAD'):
                return self.err_resp(501, b'Unsupported method (%s)' % self.method)
            file_name = self.file_name = self.translate_path()
            root = HTTP.root.encode('utf-8')
            # os.path.join(root, b'') appends exactly one separator, which
            # also works when the document root is the filesystem root
            if file_name != root and not file_name.startswith(os.path.join(root, b'')):
                return self.err_resp(403, b'Forbidden')
            if not os.path.exists(file_name):
                return self.err_resp(404, b'File not found')
            if self.managed():
                response = self.mngt_method()
            elif not os.access(file_name, os.R_OK):
                return self.err_resp(403, b'Forbidden')
            else:
                fstatdata = os.stat(file_name)
                file_type = fstatdata.st_mode & 0o170000  # file type bits
                if file_type == 0o040000:    # directory
                    for index in self.index_files:
                        candidate = file_name + b'/' + index
                        if os.path.exists(candidate) and os.access(candidate, os.R_OK):
                            return self.redirect_resp(index)
                if file_type != 0o100000:    # anything but a regular file
                    return self.err_resp(403, b'Forbidden')
                ext = os.path.splitext(file_name)[1]
                c_type = mimetypes.types_map.get(
                    ext.decode('utf-8', 'replace').lower(),
                    'text/plain').encode('utf-8')
                size = fstatdata.st_size
                resp_string = b"%s 200 Ok\r\n" % self.protocol
                resp_string += b"Content-Type: %s\r\n" % c_type
                resp_string += b"Content-Length: %d\r\n" % size
                resp_string += b'\r\n'
                if self.method == b"HEAD":
                    response = [resp_string]
                elif size > HTTP.blocksize:
                    # large file: hand the open file to handle_write(),
                    # which sends it block by block
                    response = [resp_string, open(file_name, 'rb')]
                else:
                    with open(file_name, 'rb') as small_file:
                        response = [resp_string + small_file.read()]
            self.log(200)
            return response
        except Exception:
            # top-level boundary: a failing request must not stop the server
            return self.err_resp(500, b'Internal Server Error')

    def translate_path(self):
        """Translate URL path into a path in the file system"""
        return os.path.realpath(os.path.join(HTTP.root.encode('utf-8'), *self.path.split(b'/')))

    def managed(self):
        """Test if the request can be processed by a specific method
        If so, set self.mngt_method to the method used
        This implementation tests if the url points to a cgi script"""
        if self.is_cgi():
            self.mngt_method = self.run_cgi
            return True
        return False

    def is_cgi(self):
        """Test if url points to cgi script"""
        return self.path.endswith(b".cgi")

    class StrWritableBytesIO(io.BytesIO):
        """BytesIO that also accepts str, encoded as UTF-8, in write()."""

        def write(self, str_or_byteslike):
            if isinstance(str_or_byteslike, str):
                return super().write(str_or_byteslike.encode('utf-8'))
            return super().write(str_or_byteslike)

    def run_cgi(self):
        """Execute the CGI script ``self.file_name`` and return its output.

        The script runs inside the server process via exec() with
        sys.stdout redirected to a buffer.  SystemExit raised by the script
        is ignored; any other exception replaces the output with a
        plain-text traceback.  The connection is always closed after a CGI
        response.

        SECURITY NOTE: exec() gives the script the full privileges of the
        server process; only trusted scripts should live under the root.
        """
        if not os.access(self.file_name, os.X_OK):
            return self.err_resp(403, b'Forbidden')
        # set CGI environment variables
        self.make_cgi_env()
        save_stdout = sys.stdout
        output_buffer = self.StrWritableBytesIO()
        sys.stdout = output_buffer
        # run the script
        try:
            with open(self.file_name, "rb") as script:
                source = script.read()
            exec(compile(source, self.file_name, 'exec'), {})
        except SystemExit:
            pass
        except Exception:
            output_buffer = self.StrWritableBytesIO()
            output_buffer.write(b"Content-type:text/plain\r\n\r\n")
            traceback.print_exc(file=output_buffer)
        finally:
            # restore sys.stdout even if the script raises a BaseException
            sys.stdout = save_stdout
        response = output_buffer.getvalue()
        if self.method == b"HEAD":
            # for HEAD request, don't send message body even if the script
            # returns one (RFC 3875): keep the header lines up to and
            # including the blank line that ends them
            head_lines = []
            for line in response.splitlines(keepends=True):
                head_lines.append(line)
                if not line.strip(b'\r\n'):
                    break
            response = b''.join(head_lines)
        # close connection in case there is no content-length header
        self.close_when_done = True
        resp_line = b"%s 200 Ok\r\n" % self.protocol
        return [resp_line + response]

    def make_cgi_env(self):
        """Set CGI environment variables in os.environ"""
        env = {}
        env['SERVER_SOFTWARE'] = "AsyncServer"
        env['SERVER_NAME'] = "AsyncServer"
        env['GATEWAY_INTERFACE'] = 'CGI/1.1'
        env['DOCUMENT_ROOT'] = HTTP.root
        env['SERVER_PROTOCOL'] = "HTTP/1.1"
        env['SERVER_PORT'] = str(self.server.port)

        env['REQUEST_METHOD'] = self.method.decode('utf-8')
        env['REQUEST_URI'] = self.url.decode('utf-8')
        env['PATH_TRANSLATED'] = self.translate_path().decode('utf-8')
        env['SCRIPT_NAME'] = self.path.decode('utf-8')
        env['PATH_INFO'] = urllib.parse.urlunparse(
            ("", "", "", self.rest[0].decode('utf-8'), "", ""))
        env['QUERY_STRING'] = self.rest[1].decode('utf-8')
        if self.host != self.client_address[0]:
            env['REMOTE_HOST'] = self.host
        else:
            # do not leak the value set for a previous request
            os.environ.pop('REMOTE_HOST', None)
        env['REMOTE_ADDR'] = self.client_address[0]
        env['CONTENT_LENGTH'] = str(self.headers.get(b'content-length', b''), 'utf-8')
        for name in ('USER_AGENT', 'COOKIE', 'ACCEPT', 'ACCEPT_CHARSET',
                     'ACCEPT_ENCODING', 'ACCEPT_LANGUAGE', 'CONNECTION'):
            hdr = name.lower().replace("_", "-").encode('utf-8')
            env['HTTP_%s' % name] = str(self.headers.get(hdr, b''), 'utf-8')
        os.environ.update(env)

    def redirect_resp(self, redirurl):
        """Return a 301 redirect to *redirurl* (bytes) and close afterwards.

        NOTE(review): make_response() passes a bare index file name, so the
        Location is relative to the request URL - confirm this is intended
        for directory URLs without a trailing slash.
        """
        # the trailing empty line terminates the header section
        resp_line = b"%s 301 Moved Permanently\r\nLocation: %s\r\n\r\n" % (
            self.protocol, redirurl)
        self.close_when_done = True
        self.log(301)
        return [resp_line]

    def err_resp(self, code, msg):
        """Return an error response with status *code* and reason *msg*.

        *msg* (bytes) may contain client supplied text, so CR and LF are
        replaced by spaces to keep it on the status line.
        """
        reason = msg.replace(b'\r', b' ').replace(b'\n', b' ')
        # the trailing empty line terminates the (empty) header section
        resp_line = b"%s %d %s\r\n\r\n" % (self.protocol, code, reason)
        self.close_when_done = True
        self.log(code)
        return [resp_line]

    def log(self, code):
        """Write a trace of the request on stderr"""
        if HTTP.logging:
            date_str = datetime.datetime.now().strftime('[%d/%b/%Y %H:%M:%S]')
            # the request line is client data: never let it break logging
            requestline = (self.requestline or b'').decode('utf-8', 'replace')
            sys.stderr.write('%s - - %s "%s" %s\n' %
                             (self.host, date_str, requestline, code))


# =======================================================================
# exit_err function. Exits with error code.
# =======================================================================
def exit_err(msg):
    """Write *msg* to stderr and terminate with exit status 1.

    A newline is appended when *msg* does not already end with one, so the
    message is not glued to whatever is printed next.  Raises SystemExit.
    """
    sys.stderr.write(msg)
    if not msg.endswith("\n"):
        sys.stderr.write("\n")
    # sys.exit() rather than the site-provided exit(), which is meant for
    # the interactive interpreter and is not always available
    sys.exit(1)


# =======================================================================
# fork function. Calls fork and exits from parent process.
# =======================================================================
def fork():
    """Fork the process; the parent exits with status 0, the child returns.

    When os.fork() fails, the error number and message are reported through
    exit_err().
    """
    try:
        child_pid = os.fork()
    except OSError as err:
        exit_err("fork failed: %d (%s)" % (err.errno, err.strerror))
    else:
        # a positive pid means we are still in the parent process
        if child_pid > 0:
            sys.exit(0)


# =======================================================================
# daemonize function. Sends current process to background and manages pidfile.
# =======================================================================
def daemonize(pidfile, user=None):
    """Send the current process to the background and write its pidfile.

    pidfile -- path of the file that receives the pid of the daemon
    user    -- optional user name; when given, the process switches to the
               uid/gid of that user before forking

    The pidfile is opened before privileges are dropped and written after
    the second fork, so it contains the pid of the final daemon process.
    Exits through exit_err() when the user is unknown, the pidfile cannot
    be opened or the privileges cannot be changed.
    """
    # resolve the user first so that a typo does not truncate the pidfile
    account = None
    if user:
        try:
            account = pwd.getpwnam(user)
        except KeyError:
            exit_err("unknown user: %s" % user)

    # open pidfile descriptor
    try:
        pidf = open(pidfile, 'w+')
    except IOError:
        exit_err("could not open pidfile for writing: %s" % pidfile)

    # change user from root to custom
    if account is not None:
        try:
            if os.geteuid() == 0:
                # replace root's supplementary groups with those of the
                # target user; must happen before the uid is dropped
                os.initgroups(user, account.pw_gid)
            os.setgid(account.pw_gid)
            os.setuid(account.pw_uid)
        except OSError as err:
            pidf.close()
            exit_err("could not switch to user %s: %s" % (user, err.strerror))

    # flush output buffers before forking to avoid printing something twice
    sys.stdout.flush()
    sys.stderr.flush()

    # do first fork
    fork()

    # decouple from parent environment
    os.chdir("/")
    os.setsid()
    os.umask(0)

    # do second fork
    fork()

    # redirect standard file descriptors; the temporary /dev/null handles
    # are closed once they have been duplicated
    with open('/dev/null', 'r') as nullin, open('/dev/null', 'a+') as nullout:
        os.dup2(nullin.fileno(), sys.stdin.fileno())
        os.dup2(nullout.fileno(), sys.stdout.fileno())
        os.dup2(nullout.fileno(), sys.stderr.fileno())

    # write pidfile
    pidf.write("%d\n" % os.getpid())
    pidf.close()


# Command-line entry point: parse the options, start the server, optionally
# daemonize, then run the main loop.  Invalid options, a non-numeric port
# and a failing bind are reported through exit_err() instead of a traceback.
if __name__ == "__main__":
    VERBOSE = False
    HOST = 'any'
    PORT = 9425
    ROOTPATH = "/usr/share/sfscgi"
    PIDFILE = None
    USER = None

    try:
        OPTS, ARGS = getopt.getopt(sys.argv[1:], "vhH:P:R:p:u:")
    except getopt.GetoptError as err:
        exit_err("%s: %s (use -h for help)\n" % (sys.argv[0], err))
    for opt, val in OPTS:
        if opt == '-h':
            print("usage: %s [-H bind_host] [-P bind_port] [-R rootpath] [-v]\n" % sys.argv[0])
            print("-H bind_host : local address to listen on (default: any)")
            print("-P bind_port : port to listen on (default: 9425)")
            print("-R rootpath : local path to use as HTTP document root (default: /usr/share/sfscgi)")
            print("-v : log requests on stderr")
            print("-p : pidfile path, setting it triggers manual daemonization")
            print("-u : username of server owner, used in manual daemonization")
            sys.exit(0)
        elif opt == '-H':
            HOST = val
        elif opt == '-P':
            try:
                PORT = int(val)
            except ValueError:
                exit_err("invalid port number: %s\n" % val)
        elif opt == '-R':
            ROOTPATH = val
        elif opt == '-v':
            VERBOSE = True
        elif opt == '-p':
            PIDFILE = val
        elif opt == '-u':
            USER = val

    # launch the server on the specified port
    try:
        SERVER = Server(HOST, PORT)
    except OSError as err:
        exit_err("could not start server on port %s: %s\n" % (PORT, err))
    if HOST != 'any':
        print("Asynchronous HTTP server running on %s:%s" % (HOST, PORT))
    else:
        print("Asynchronous HTTP server running on port %s" % PORT)
    HTTP.logging = bool(VERBOSE)
    HTTP.root = os.path.realpath(ROOTPATH)
    if PIDFILE:
        daemonize(PIDFILE, USER)
    loop(SERVER, HTTP)
