#!/usr/bin/python
#
# Copyright (C) 2011 Citrix Systems, Inc.
#
# This library is free software; you can redistribute it and/or modify
# it under the terms of version 2.1 of the GNU Lesser General Public
# License as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library; If not, see <http://www.gnu.org/licenses/>.
#

"""Overview:

        - Gather Xen I/O ring states
          (from %s/*/ring)

        - Update ring states every -T seconds.

        - Determine if rings are idle or make progress.

        - Determine if idle rings dropped notifications (%s).

        - Instruct stuck backends to reissue notifications.
"""

import os
import glob

class Pattern(object):
    """Lazily compiled regular expression.

    The expression source is kept in `regex`. The compiled object is
    built by the first get() and cached for all later calls.
    """

    def __init__(self, regex):
        self.__pattern = None
        self.regex     = regex

    def get(self):
        """Return the compiled pattern, compiling `regex` on first use."""
        # NB. re is imported locally; this file has no file-level import
        # of it.
        import re

        compiled = self.__pattern
        if compiled is None:
            compiled = re.compile(self.regex)
            self.__pattern = compiled

        return compiled

    def search(self, s):
        """Search string `s`; return the match object, or None."""
        compiled = self.get()
        return compiled.search(s)

class XenBackend(object):
    """A Xen I/O backend.

    Base class for the backend device types found below SYSFS_BASEDIR.
    Concrete types provide XEN_BACKEND_NAME, _name_pattern, _name_glob,
    name() and a nested Ring class carrying its own _name_glob.
    """

    SYSFS_BASEDIR = "/sys/devices/xen-backend"

    def __init__(self, rd, devid):
        # NB. both ids may arrive as numeric strings (from_name() passes
        # regex groups), so normalize to int.
        self.rd    = int(rd)
        self.devid = int(devid)

    def __repr__(self):
        return "%s(%d, %d)" % (type(self).__name__,
                               self.rd, self.devid)

    def name(self):
        """Return the backend's directory name below SYSFS_BASEDIR.
        Must be provided by subclasses."""
        raise NotImplementedError

    def path(self):
        """Return the backend's sysfs directory path."""
        return "%s/%s" % (self.SYSFS_BASEDIR, self.name())

    # Object with a search() method whose groups 1 and 2 capture rd
    # and devid from a backend name. Set by subclasses.
    _name_pattern = None

    @classmethod
    def from_name(cls, name):
        """Build a backend from its sysfs directory name.

        Raises Exception if `name` does not match _name_pattern.
        """
        match = cls._name_pattern.search(name)
        if not match:
            # NB. this is a classmethod, there is no instance to ask
            # for its type. Name the class through cls.
            raise Exception("Malformed %s name: %s" %
                            (cls.__name__, name))

        rd    = match.group(1)
        devid = match.group(2)

        return cls(rd, devid)

    # Glob matching backend directory names. Set by subclasses.
    _name_glob = None

    @classmethod
    def find(cls):
        """Generate a backend for every entry below SYSFS_BASEDIR
        matching _name_glob."""
        paths = glob.glob("%s/%s" % (cls.SYSFS_BASEDIR,
                                     cls._name_glob))
        for path in paths:
            name = os.path.basename(path)
            yield cls.from_name(name)

    def find_rings(self):
        """Generate the rings found in this backend's directory."""
        for ring in self.Ring.find(self):
            yield ring

    class Ring(object):
        """An I/O ring node within a backend's sysfs directory."""

        def __init__(self, backend, name):
            self.backend = backend
            self.name    = name

        # NB. not referenced anywhere in this file.
        __size = None

        def key(self):
            """Return '<backend name>/<ring name>', identifying the
            ring across updates."""
            return "%s/%s" % (self.backend.name(),
                              self.name)

        def __str__(self):
            return "%s(%s)" % (type(self).__name__, self.key())

        @classmethod
        def from_name(cls, backend, name):
            return cls(backend, name)

        # Glob matching ring file names. Set by subclasses.
        _name_glob = None

        @classmethod
        def find(cls, backend):
            """Generate a ring for every file in the backend's directory
            matching _name_glob."""
            paths = glob.glob("%s/%s" % (backend.path(),
                                         cls._name_glob))
            for path in paths:
                name = os.path.basename(path)
                yield cls.from_name(backend, name)

        def path(self):
            """Return the ring's sysfs file path."""
            return "%s/%s" % (self.backend.path(),
                              self.name)

        def read(self):
            """Read and parse the ring node. Returns a RingState."""
            state = RingState.from_sysfs(self.path())
            return state

        def write(self, cmd):
            """Write `cmd`, stripped of trailing whitespace, to the ring
            node. May raise IOError."""
            f = open(self.path(), 'w')
            try:
                f.write(cmd.rstrip())
            finally:
                f.close()

        def kick(self):
            """Write the "kick" command to the ring node."""
            self.write("kick")

        def poll(self):
            """Write the "poll" command to the ring node."""
            self.write("poll")

    # NB. not referenced anywhere in this file.
    __ring = None

    # Registered backend types, keyed by XEN_BACKEND_NAME.
    TYPES = {}
    XEN_BACKEND_NAME = None

    @classmethod
    def register(cls):
        """Add this type to XenBackend.TYPES."""
        XenBackend.TYPES[cls.XEN_BACKEND_NAME] = cls

class VBD(XenBackend):
    """Xen blkif backends."""

    XEN_BACKEND_NAME = 'vbd'

    # NB. raw string: "\d" is a regex class, not a string escape.
    # Groups 1 and 2 capture rd and devid.
    _name_pattern = Pattern(r"vbd-(\d+)-(\d+)")
    _name_glob    = "vbd-*-*"

    def name(self):
        # Inverse of _name_pattern.
        return "vbd-%d-%d" % (self.rd, self.devid)

    class Ring(XenBackend.Ring):
        # Ring file name within the backend directory.
        _name_glob = "io_ring"

# Make the type available in XenBackend.TYPES.
VBD.register()

class VIF(XenBackend):
    """Xen netif backends."""

    XEN_BACKEND_NAME = 'vif'

    # NB. raw string: "\d" is a regex class, not a string escape.
    # Groups 1 and 2 capture rd and devid.
    _name_pattern = Pattern(r"vif-(\d+)-(\d+)")
    _name_glob    = "vif-*-*"

    def name(self):
        # Inverse of _name_pattern.
        return "vif-%d-%d" % (self.rd, self.devid)

    class Ring(XenBackend.Ring):
        # NB. glob.glob() does not expand braces, "{rx,tx}_ring" would
        # only match a file literally named that way. A character class
        # selects both rx_ring and tx_ring.
        # NOTE(review): assumes netif backends expose their rings under
        # these two names - confirm before registering this type.
        _name_glob = "[rt]x_ring"

#VIF.register()

class RingState(object):
    """Overall backend ring state. Comprising req and rsp queue
    indexes, and analysis."""

    def __init__(self, size, req, rsp):
        self.size = int(size)
        self.req  = req
        self.rsp  = rsp

    # Matches the first line of a ring node, e.g. "nr_ents 32".
    _size_pattern = Pattern(r"nr_ents (\d+)")

    @classmethod
    def from_sysfs(cls, path):
        """Read and parse the ring node at `path`.

        The content must split on newlines into exactly four fields:
        the nr_ents line, the req line, the rsp line and a remainder,
        which is ignored.

        Raises IOError if the node cannot be read, and Exception if the
        content is malformed.
        """

        f = open(path, "r")
        try:
            s = f.read()
        finally:
            f.close()

        try:
            (_nr_ents, _req, _rsp, _) = s.split("\n")

            match   = cls._size_pattern.search(_nr_ents)
            nr_ents = int(match.group(1))

        except (ValueError, AttributeError) as e:
            # ValueError: wrong number of lines.
            # AttributeError: no nr_ents match (match is None).
            raise Exception("Malformed %s input: %s (%s)" %
                            (cls.__name__, repr(s), str(e)))

        req = cls.Req.from_sysfs(_req, size=nr_ents)
        rsp = cls.Rsp.from_sysfs(_rsp, size=nr_ents)

        return cls(nr_ents, req, rsp)

    class Queue(dict):
        """Base of the req and rsp index sets. Subclasses provide
        _pattern and _cons()."""

        def __init__(self, size):
            self.size = int(size)

        prod = None

        @classmethod
        def from_sysfs(cls, line, **d):
            """Parse one line of a ring node.

            _pattern captures alternating (name, value) groups. These
            are passed, together with any extra keywords in `d`, to the
            constructor.

            Raises Exception if `line` does not match.
            """

            match = cls._pattern.search(line)
            if not match:
                raise Exception("Malformed %s input: %s" %
                                (cls.__name__, repr(line)))

            # NB. groups come as name, value, name, value, ...
            groups = match.groups()
            d.update(zip(groups[0::2], groups[1::2]))

            return cls(**d)

        def is_consumed(self):
            """True if the consumer index has caught up with prod."""
            return self.prod == self._cons()

    class Req(Queue):
        """Request queue indexes."""

        _pattern = Pattern(r"req (prod) (\d+) (cons) (\d+) (event) (\d+)")

        def __init__(self, prod, cons, event, **d):
            RingState.Queue.__init__(self, **d)
            self.prod  = int(prod)
            self.cons  = int(cons)
            self.event = int(event)

        def __repr__(self):
            return "%s(prod=%d, cons=%d, event=%d)" % \
                (type(self).__name__, self.prod, self.cons, self.event)

        def _cons(self):
            return self.cons

        def __eq__(self, other):
            return \
                self.prod  == other.prod and \
                self.cons  == other.cons and \
                self.event == other.event

        def __ne__(self, other):
            # NB. python2 does not derive != from __eq__.
            return not self == other

    class Rsp(Queue):
        """Response queue indexes."""

        _pattern = Pattern(r"rsp (prod) (\d+) (pvt) (\d+) (event) (\d+)")

        def __init__(self, prod, pvt, event, **d):
            RingState.Queue.__init__(self, **d)
            self.prod  = int(prod)
            self.pvt   = int(pvt)
            self.event = int(event)

        def __repr__(self):
            return "%s(prod=%d, pvt=%d, event=%d)" % \
                (type(self).__name__, self.prod, self.pvt, self.event)

        def _cons(self):
            # NB. no consumer index is reported for responses, it is
            # derived from the event index.
            # NOTE(review): presumably the peer sets event to its
            # consumer index + 1 - confirm against the ring protocol.
            return self.event - 1

        def __eq__(self, other):
            return \
                self.prod  == other.prod and \
                self.pvt   == other.pvt  and \
                self.event == other.event

        def __ne__(self, other):
            # NB. python2 does not derive != from __eq__.
            return not self == other

    def is_consumed(self):
        """True if both the rsp and the req queue are consumed."""
        return \
            self.rsp.is_consumed() and \
            self.req.is_consumed()

    def is_pending(self):
        """True while the rsp producer index differs from the req
        producer index."""
        return self.rsp.prod != self.req.prod

    def kick(self, ring):
        """Prod `ring` according to this state.

        Unconsumed requests trigger ring.poll(), unconsumed responses
        trigger ring.kick(). Returns True if either was issued.
        """
        action = False

        if not self.req.is_consumed():
            action = True
            ring.poll()

        if not self.rsp.is_consumed():
            action = True
            ring.kick()

        return action

    def __eq__(self, other):
        return \
            self.size == other.size and \
            self.req == other.req and \
            self.rsp == other.rsp

    def __ne__(self, other):
        # NB. python2 does not derive != from __eq__.
        return not self == other

    def __repr__(self):
        return "%s(size=%d, %s, %s)" % \
            (type(self).__name__, self.size, self.req, self.rsp)

    def display(self):
        """Return a one-line summary: the state itself, followed by
        complete/pending verdicts for io, req and rsp."""
        complete = { True: "complete", False: "pending" }

        io  = complete[not self.is_pending()]
        req = complete[self.req.is_consumed()]
        rsp = complete[self.rsp.is_consumed()]

        return "%s: io: %s, req: %s, rsp: %s" % (self, io, req, rsp)

class RingWatch(object):
    """State machine tracking the observed state of a single I/O ring.

    A watch starts out as _NEW. Every update() compares a fresh reading
    of the ring against the previous one and moves the watch to BUSY,
    IDLE or STCK.
    """

    _NEW  = "_NEW"
    BUSY  = "BUSY"
    IDLE  = "IDLE"
    STCK  = "STCK"

    COMMENTS = { BUSY: "Message traffic observed (OK)",
                 IDLE: "No messages observed (Ring OK, I/O depends)",
                 STCK: "No pending req/rsp consumer progress observed (BUG)" }

    def __init__(self, ring, state):
        self.status = RingWatch._NEW
        self.ring   = ring
        self.state  = state

    @classmethod
    def new(cls, ring):
        """Create a watch from an initial reading of `ring`."""
        return cls(ring, ring.read())

    def __str__(self):
        label = type(self).__name__
        return "%s(%s)[%s]" % (label, self.ring.key(), self.status)

    def is_stuck(self):
        """True if the last update() found the ring stuck."""
        return self.status == self.STCK

    def is_idle(self):
        """True if the last update() found the ring idle."""
        return self.status == self.IDLE

    def kick(self):
        """Kick the ring, but only if it is stuck.

        Returns whatever the state's kick() returns, or None if the
        ring is not stuck.
        """
        if not self.is_stuck():
            return None

        return self.state.kick(self.ring)

    def update(self):
        """Take a new reading of the ring and reclassify the watch."""

        latest = self.ring.read()

        if not (latest == self.state):
            # indexes moved since the previous reading.
            status = self.BUSY
        elif latest.is_consumed():
            status = self.IDLE
        else:
            # no movement, yet unconsumed entries remain.
            status = self.STCK

        self.status = status
        self.state  = latest

    def display(self):
        """Return a one-line description of watch and ring state."""
        summary = self.state.display()
        return "%s: %s" % (self, summary)

class WatchList(object):
    """Managed collection of I/O rings under surveillance.

    `gen` is a callable returning an iterable of the rings to watch. It
    is invoked anew by every update(). Watches are kept in `list`, a
    dict keyed by ring key.
    """

    def __init__(self, gen):
        self.gen  = gen
        self.list = {}

    def update(self):
        """Rescan the rings and refresh all watches."""

        # NB. clear the watch list, then rebuild it. new entries get
        # added, existing ones updated, those gone discarded.
        prev      = self.list
        self.list = {}

        for ring in self.gen():

            key   = ring.key()
            entry = prev.get(key)

            try:
                if entry is None:
                    entry = RingWatch.new(ring)
                else:
                    entry.update()

            except IOError:
                # NB. racing unplug, any ring.read() may raise.
                # nothing left to memorize then.
                continue

            self.list[key] = entry

    def __iter__(self):
        # NB. values(), not itervalues(): works on python2 and python3.
        return iter(self.list.values())

    def pending(self):
        """Generate the watches which are idle, but whose ring state
        still has I/O pending."""
        for entry in self:
            if entry.is_idle() and entry.state.is_pending():
                yield entry

    def stuck(self):
        """Generate the watches found stuck."""
        for entry in self:
            if entry.is_stuck():
                yield entry

    def kick(self):
        """Kick every stuck watch."""
        for entry in self.stuck():
            try:
                entry.kick()
            except IOError:
                # NB. racing unplug, any ring.write() may raise.
                pass

if __name__ == '__main__':
    # Command line front end: parse the options, then run the command
    # given as first positional argument.
    from sys import argv, stdout, stderr, exit
    from getopt import gnu_getopt, GetoptError

    DEFAULT_PERIOD = 1 # secs

    verbose  = 0
    period   = DEFAULT_PERIOD
    # NB. must be a real list, it is traversed once per watch update.
    backends = list(XenBackend.TYPES.values())
    kick     = False
    iowatch  = False

    # ((short, long), description) per option, for the help screen.
    OPTIONS = ((('h', 'help'),
                "Print this help screen."),

               (('v', 'verbose'),
                "Increase output verbosity level (use n-times)."),

               (('I', 'io'),
                "Watch out for stuck I/O (not messaging), too. (%s)" % \
                    (iowatch)),

               # NB. long name as accepted by gnu_getopt() below.
               (('t', 'type'),
                "Comma separated list of backend types to watch. (%s)" % \
                    ",".join([t.XEN_BACKEND_NAME for t in backends])),

               (('T', 'period'),
                "Watch update period. (%d) [secs]" % \
                    (period)),

               (('k', 'kick'),
                "Kick broken guests out of cardiac arrest. (%s)" % \
                    (kick))
               )

    COMMANDS = {"check":
                    "Single iteration quick test (takes -T seconds)."}

    def emit(stream, text = ""):
        """Write one line of text to stream."""
        stream.write(text + "\n")

    def usage(stream):
        """Write the usage screen to stream."""
        prog = os.path.basename(argv[0])

        emit(stream)

        emit(stream, "Usage:")
        emit(stream, "\t%s [options] {%s}" % (prog, "|".join(COMMANDS)))

        emit(stream)

        emit(stream, "Commands:")
        for (name, desc) in COMMANDS.items():
            emit(stream, "\t%s: \t%s" % (name, desc))

        emit(stream)

        emit(stream, "Options:")
        for ((short, _long), desc) in OPTIONS:
            emit(stream, "\t-%s, --%s: \t%s" % (short, _long, desc))

        emit(stream)

    def fail(msg = None):
        """Report msg, if any, and the usage on stderr. Exit with
        status 1."""
        if msg:
            emit(stderr, "Error: %s" % msg)
        usage(stderr)
        exit(1)

    def show_help():
        """Write the full help screen to stdout."""

        usage(stdout)

        emit(stdout, __doc__ % (XenBackend.SYSFS_BASEDIR, RingWatch.STCK))

        emit(stdout, "Backend Types:")
        for k, v in XenBackend.TYPES.items():
            emit(stdout, "\t%s: \t%s (%s)" % (k, v.__doc__, v._name_glob))

        emit(stdout)
        emit(stdout, "Ring States:")
        for k, v in RingWatch.COMMENTS.items():
            emit(stdout, "\t%s: \t%s" % (k, v))

        emit(stdout)

    try:
        opts, args = gnu_getopt(argv[1:],
                                "hIkt:vT:",
                                ["help",
                                 "io",
                                 "kick",
                                 "type=",
                                 "verbose",
                                 "period="])
    except GetoptError as e:
        fail(str(e))

    for (o, arg) in opts:
        try:
            if o in ('-h', '--help'):
                show_help()
                exit(0)

            elif o in ('-v', '--verbose'):
                verbose += 1

            elif o in ('-I', '--io'):
                iowatch = True

            elif o in ('-T', '--period'):
                period = int(arg)
                if period < 0:
                    # NB. cannot sleep for a negative time.
                    raise ValueError(arg)

            elif o in ('-t', '--type'):
                # NB. unknown type names raise KeyError, see below.
                backends = [XenBackend.TYPES[t] for t in arg.split(",")]

            elif o in ('-k', '--kick'):
                kick = True

            else:
                raise Exception("BUG: option %s unhandled." % o)

        except (ValueError, KeyError):
            fail("%s: invalid argument '%s'." % (o, arg))

    try:
        cmd = args[0]
    except IndexError:
        fail("Missing command.")

    def ring_select():
        """Generate all rings of all backends of the selected types."""
        for _type in backends:
            for backend in _type.find():
                for ring in backend.find_rings():
                    yield ring

    def show(entries):
        for watch in entries:
            emit(stdout, watch.display())

    def pause():
        import time
        time.sleep(period)

    watches = WatchList(ring_select)

    if cmd == "check":

        # init
        watches.update()

        if verbose >= 2:
            show(watches)

        # watch for one round
        pause()
        watches.update()

        # show result
        crit  = list(watches.stuck())
        stuck = bool(crit)

        if iowatch:
            crit.extend(watches.pending())

        if verbose >= 1:
            show(watches)
        elif crit:
            show(crit)

        if stuck and kick:
            # deal with it
            watches.kick()

    else:
        fail("Invalid command.")