r7287 - in firehose: . firehose/jobs

Author: walters
Date: 2008-02-06 11:01:42 -0600 (Wed, 06 Feb 2008)
New Revision: 7287

Added:
   firehose/firehose/jobs/logutil.py
Modified:
   firehose/dev.cfg
   firehose/firehose/jobs/master.py
   firehose/firehose/jobs/poller.py
Log:
Flesh out sending data to the master.

Add logutil, with a log_except decorator that logs exceptions raised in
callbacks.

Handle If-Modified-Since correctly: only send the header when a previous
timestamp exists, and parse Last-Modified with its timezone offset.


Modified: firehose/dev.cfg
===================================================================
--- firehose/dev.cfg	2008-02-05 20:29:29 UTC (rev 7286)
+++ firehose/dev.cfg	2008-02-06 17:01:42 UTC (rev 7287)
@@ -12,8 +12,9 @@
 # sqlobject.dburi="sqlite:///file_name_and_path"
 
 firehose.taskdbpath="%(current_dir_uri)s/dev-tasks.sqlite"
-firehose.masterhost="localhost"
+firehose.masterhost="localhost:6676"
 firehose.slaveport="6677"
+firehose.clienturl="http://localhost:8080/extservice/notify-polling-tasks"
 
 # if you are using a database or table type without transactions
 # (MySQL default, for example), you should turn off transactions

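These settings are consumed via TurboGears config lookups in the job modules
below; a minimal sketch (values mirror dev.cfg):

from turbogears import config

client_url = config.get('firehose.clienturl')        # where the master POSTs change notifications
master_hostport = config.get('firehose.masterhost')  # now carries host:port
slave_port = int(config.get('firehose.slaveport'))   # the poller's HTTP bind port
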
Added: firehose/firehose/jobs/logutil.py
===================================================================
--- firehose/firehose/jobs/logutil.py	2008-02-05 20:29:29 UTC (rev 7286)
+++ firehose/firehose/jobs/logutil.py	2008-02-06 17:01:42 UTC (rev 7287)
@@ -0,0 +1,15 @@
+#!/usr/bin/python
+
+import logging
+
+def log_except(logger=None, text=''):
+    def annotate(func):
+        def _exec_cb(*args, **kwargs):
+            try:
+                return func(*args, **kwargs)
+            except:
+                # fall back to the root logging module if no logger was given
+                log_target = logger or logging
+                log_target.exception('Exception in callback%s', text and (': '+text) or '')
+        return _exec_cb
+    return annotate

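A minimal usage sketch of the decorator (the function and logger names are
hypothetical):

import logging
from firehose.jobs.logutil import log_except

logging.basicConfig()
_logger = logging.getLogger('firehose.Example')

@log_except(_logger, text='demo callback')
def flaky_callback():
    raise ValueError('simulated failure')

# The traceback is logged via _logger.exception() instead of propagating,
# and the wrapper returns None to the caller.
flaky_callback()
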
Modified: firehose/firehose/jobs/master.py
===================================================================
--- firehose/firehose/jobs/master.py	2008-02-05 20:29:29 UTC (rev 7286)
+++ firehose/firehose/jobs/master.py	2008-02-06 17:01:42 UTC (rev 7287)
@@ -1,7 +1,7 @@
 #!/usr/bin/python
 
 import os,sys,re,heapq,time,httplib,logging,threading
-import traceback
+import traceback,urlparse
 import BaseHTTPServer,urllib
 
 if sys.version_info[0] < 2 or sys.version_info[1] < 5:
@@ -17,8 +17,9 @@
 import turbogears
 from turbogears import config
 
-from tasks import TaskEntry
+from firehose.jobs.tasks import TaskEntry
 from firehose.jobs.poller import TaskPoller
+from firehose.jobs.logutil import log_except
 
 _logger = logging.getLogger('firehose.Master')
 _logger.debug("hello master!")
@@ -64,10 +65,15 @@
     def __init__(self):
         global _instance
         assert _instance is None
-        self.__tasks = []
+        self.__tasks_queue = [] # priority queue
+        self.__tasks_map = {} # maps id -> task
+        self.__changed_buffer = []
+        self.__changed_thread_queued = False
         self.__poll_task = None        
         self.__task_lock = threading.Lock()
         
+        self.__client_url = config.get('firehose.clienturl')
+        
         # Default to one slave on localhost
         self.__worker_endpoints = ['localhost:%d' % (int(config.get('firehose.slaveport')),)]
         _logger.debug("worker endpoints are %r", self.__worker_endpoints)
@@ -90,21 +96,33 @@
         curtime = time.time()
         for key,prev_hash,prev_time in cursor.execute('''SELECT key,prev_hash,prev_time from Tasks'''):
             task = QueuedTask(curtime, TaskEntry(key, prev_hash, prev_time))
-            heapq.heappush(self.__tasks, task)
+            heapq.heappush(self.__tasks_queue, task)
+            self.__tasks_map[key] = task
         conn.close()
-        _logger.debug("%d queued tasks", len(self.__tasks))
+        _logger.debug("%d queued tasks", len(self.__tasks_queue))
     
-    def __add_task_for_key(self, key):
-        try:
-            self.__task_lock.acquire()        
+    def __add_task_keys_unlocked(self, keys):
+        for key in keys:
+            if key in self.__tasks_map:
+                continue 
             task = TaskEntry(key, None, None)
-            for qtask in self.__tasks:
-                if qtask.task == task:
-                    return qtask
             qtask = QueuedTask(time.time(), task)
-            self.__tasks.append(qtask)
+            heapq.heappush(self.__tasks_queue, qtask)
+            self.__tasks_map[key] = qtask
+            
+    def __add_task_keys(self, keys):
+        try:
+            self.__task_lock.acquire()
+            self.__add_task_keys_unlocked(keys)
         finally:
             self.__task_lock.release()
+    
+    def __add_task_for_key(self, key):
+        try:
+            self.__task_lock.acquire()
+            self.__add_task_keys_unlocked([key])
+        finally:
+            self.__task_lock.release()
             
     def add_feed(self, feedurl):
         taskkey = 'feed/' + urllib.quote(feedurl)
@@ -119,8 +137,9 @@
     
     def add_tasks(self, taskkeys):
         _logger.debug("adding %d task keys", len(taskkeys))
-        for taskkey in taskkeys:
-            self.__add_task_for_key(taskkey)
+        # Append them to the in-memory state
+        self.__add_task_keys(taskkeys)
+        # Persist them
         try:
             conn = sqlite3.connect(self.__path, isolation_level=None)        
             cursor = conn.cursor()
@@ -135,10 +154,52 @@
         cursor.execute('''INSERT OR REPLACE INTO Tasks VALUES (?, ?, ?)''',
                        (taskkey, hashcode, timestamp))
     
+    @log_except(_logger)
+    def __push_changed(self):
+        try:
+            self.__task_lock.acquire()
+            self.__changed_thread_queued = False
+            changed = self.__changed_buffer
+            self.__changed_buffer = []
+        finally:
+            self.__task_lock.release()
+        jsonstr = simplejson.dumps(changed)
+        parsed = urlparse.urlparse(self.__client_url)
+        conn = httplib.HTTPConnection(parsed.hostname, parsed.port)
+        conn.request('POST', parsed.path or '/', jsonstr)
+        conn.close()        
+
+    def __append_changed(self, changed):
+        try:
+            self.__task_lock.acquire()
+            self.__changed_buffer.extend(changed)
+            if not self.__changed_thread_queued:
+                thr = threading.Thread(target=self.__push_changed)
+                thr.setDaemon(True)
+                thr.start()
+                self.__changed_thread_queued = True
+        finally:
+            self.__task_lock.release()
+
     def taskset_status(self, results):
         _logger.info("got %d results", len(results))
-        _logger.debug("results: %r", results    )
+        changed = []
         try:
+            self.__task_lock.acquire()
+            for (taskkey, hashcode, timestamp) in results:
+                try:
+                    curtask = self.__tasks_map[taskkey]
+                except KeyError, e:
+                    _logger.exception("failed to find task key %r", taskkey)
+                    continue
+                if curtask.task.prev_hash != hashcode:
+                    _logger.debug("task %r: new hash for %r differs from prev %r", 
+                                  taskkey, hashcode, curtask.task.prev_hash)
+                    changed.append(taskkey)
+        finally:
+            self.__task_lock.release()
+        self.__append_changed(changed)            
+        try:
             conn = sqlite3.connect(self.__path, isolation_level=None)
             cursor = conn.cursor()
             cursor.execute('''BEGIN''')   
@@ -164,6 +225,7 @@
     def __activate_workers(self):
         raise NotImplementedError()
     
+    @log_except(_logger)
     def __enqueue_taskset(self, worker, taskset):
         jsonstr = simplejson.dumps(taskset)
         conn = httplib.HTTPConnection(worker)
@@ -182,7 +244,7 @@
             i = 0 
             while True:          
                 try:
-                    task = heapq.heappop(self.__tasks)
+                    task = heapq.heappop(self.__tasks_queue)
                 except IndexError, e:
                     break
                 if i >= MAX_TASKSET_SIZE:
@@ -195,7 +257,7 @@
                     i += 1
                 eligible = task.eligibility < taskset_limit
                 task.eligibility = curtime + DEFAULT_POLL_TIME_SECS
-                heapq.heappush(self.__tasks, task)                 
+                heapq.heappush(self.__tasks_queue, task)                 
                 if eligible:
                     taskset.append((str(task.task), task.task.prev_hash, task.task.prev_timestamp))
                 else:
@@ -226,11 +288,11 @@
             self.__task_lock.acquire()
                     
             assert self.__poll_task is None
-            if len(self.__tasks) == 0:
+            if len(self.__tasks_queue) == 0:
                 _logger.debug("no tasks")
                 return
             curtime = time.time()
-            next_timeout = self.__tasks[0].eligibility - curtime
+            next_timeout = self.__tasks_queue[0].eligibility - curtime
             if immediate:
                 next_timeout = 1
             elif (next_timeout < MIN_POLL_TIME_SECS):

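A minimal sketch of the notification path added above: changed task keys are
buffered, then a daemon thread POSTs them as JSON to firehose.clienturl (the
task key below is hypothetical):

import httplib, urlparse
import simplejson

client_url = 'http://localhost:8080/extservice/notify-polling-tasks'
changed = ['feed/http%3A//example.com/feed.xml']

jsonstr = simplejson.dumps(changed)
parsed = urlparse.urlparse(client_url)
conn = httplib.HTTPConnection(parsed.hostname, parsed.port)
conn.request('POST', parsed.path or '/', jsonstr)
# matching __push_changed, the response is not read before closing
conn.close()
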
Modified: firehose/firehose/jobs/poller.py
===================================================================
--- firehose/firehose/jobs/poller.py	2008-02-05 20:29:29 UTC (rev 7286)
+++ firehose/firehose/jobs/poller.py	2008-02-06 17:01:42 UTC (rev 7287)
@@ -2,7 +2,7 @@
 
 import os,sys,re,heapq,time,Queue,sha,threading
 import BaseHTTPServer,httplib,urlparse,urllib
-from email.utils import formatdate,parsedate
+from email.utils import formatdate,parsedate_tz,mktime_tz
 import logging
 
 import boto
@@ -15,6 +15,8 @@
 import simplejson
 from turbogears import config
 
+from firehose.jobs.logutil import log_except
+
 _logger = logging.getLogger('firehose.Poller')
 
 aws_config_path = os.path.expanduser('~/.aws')
@@ -40,20 +42,23 @@
         try:
             _logger.info('Connecting to %r', targeturl)
             connection = httplib.HTTPConnection(parsedurl.hostname, parsedurl.port)
-            connection.request('GET', parsedurl.path,
-                               headers={'If-Modified-Since':
-                                        formatdate(prev_timestamp)})
+            headers = {}
+            if prev_timestamp is not None:
+                headers['If-Modified-Since'] = formatdate(prev_timestamp)            
+            connection.request('GET', parsedurl.path, headers=headers)
             response = connection.getresponse()
             if response.status == 304:
                 _logger.info("Got 304 Unmodified for %r", targeturl)
-                return (prev_hash, prev_timestamp) 
-            data = response.read()
-            hash = sha.new()
-            hash.update(data)
+                return (prev_hash, prev_timestamp)
+            hash = sha.new()            
+            buf = response.read(8192)
+            while buf:
+                hash.update(buf)
+                buf = response.read(8192)
             hash_hex = hash.hexdigest()
             timestamp_str = response.getheader('Last-Modified', None)
             if timestamp_str is not None:
-                timestamp = parsedate(timestamp_str)
+                timestamp = mktime_tz(parsedate_tz(timestamp_str))
             else:
                 _logger.debug("no last-modified for %r", targeturl)
                 timestamp = time.time()
@@ -88,22 +93,24 @@
         bindport = int(config.get('firehose.slaveport'))
         self.__server = BaseHTTPServer.HTTPServer(('', bindport), TaskRequestHandler)
         self.__active_collectors = set()
-        self.__master_hostport = (config.get('firehose.masterhost'), 8080)
+        self.__master_hostport = config.get('firehose.masterhost')
         
     def run_async(self):
         thr = threading.Thread(target=self.run)
         thr.setDaemon(True)
         thr.start()
         
+    @log_except(_logger)        
     def run(self):
         self.__server.serve_forever()
         
     def __send_results(self, results):
         dumped_results = simplejson.dumps(results)
-        connection = httplib.HTTPConnection(*(self.__master_hostport))
+        connection = httplib.HTTPConnection(self.__master_hostport)
         connection.request('POST', '/taskset_status', dumped_results,
                            headers={'Content-Type': 'text/javascript'})
         
+    @log_except(_logger)        
     def __run_collect_tasks(self, taskqueue, resultqueue):
         _logger.debug("doing join on taskqueue")
         taskqueue.join()
@@ -114,9 +121,11 @@
                 result = resultqueue.get(False)
                 results.append(result)
             except Queue.Empty:
-                break 
+                break
+        _logger.debug("sending %d results", len(results))            
         self.__send_results(results)
         
+    @log_except(_logger)        
     def __run_task(self, taskid, prev_hash, prev_timestamp, taskqueue, resultqueue):
         (family, tid) = taskid.split('/', 1)
         try:

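A minimal sketch of the conditional-GET behavior fixed above, combining the
If-Modified-Since header with timezone-aware Last-Modified parsing and
chunked hashing (the URL is hypothetical):

import httplib, sha, time, urlparse
from email.utils import formatdate, parsedate_tz, mktime_tz

url = 'http://example.com/feed.xml'
prev_timestamp = None  # epoch seconds from the previous poll, if any

parsed = urlparse.urlparse(url)
connection = httplib.HTTPConnection(parsed.hostname, parsed.port)
headers = {}
if prev_timestamp is not None:
    # only send the header once we actually have a previous timestamp
    headers['If-Modified-Since'] = formatdate(prev_timestamp)
connection.request('GET', parsed.path, headers=headers)
response = connection.getresponse()
if response.status == 304:
    print 'unchanged'
else:
    # hash the body in 8k chunks rather than reading it all into memory
    hash = sha.new()
    buf = response.read(8192)
    while buf:
        hash.update(buf)
        buf = response.read(8192)
    timestamp_str = response.getheader('Last-Modified', None)
    if timestamp_str is not None:
        # parsedate_tz()/mktime_tz() honor the timezone offset, unlike
        # parsedate(), which this revision replaces
        timestamp = mktime_tz(parsedate_tz(timestamp_str))
    else:
        timestamp = time.time()
    print hash.hexdigest(), timestamp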