[PATCH 4 of 4] statprof: record samples by thread (RFC)

Gregory Szorc gregory.szorc at gmail.com
Sat Nov 5 00:53:43 EDT 2016


# HG changeset patch
# User Gregory Szorc <gregory.szorc at gmail.com>
# Date 1478321340 25200
#      Fri Nov 04 21:49:00 2016 -0700
# Node ID dab278ccedfb41d7a255c313fc426918ef27014f
# Parent  dfd8e1f11e33fc1ba6cba429818ad0ea6e84dadc
statprof: record samples by thread (RFC)

Previously, statprof only recorded samples for the thread that
first started the profiler. If another thread started profiling
(as is the case with hgweb requests), samples for that thread
would not be recorded.

With this commit, we teach statprof to record samples for multiple
threads. We maintain a global set of threads that have profiling
enabled. The background thread that collects frames periodically
iterates active threads and records samples for threads with
profiling enabled.

When a thread stops profiling, it is removed from the set of active
threads and the samples recorded for that thread are returned.

I'm not terribly excited about the state of this patch. Specifically,
I don't like changing state.samples to a list of (tid, sample) tuples;
I think recording samples in per-thread lists is preferable. I'm
throwing this patch out there to see how it is received. Previously,
I had implemented "stack based" profiling where only the last thread
to call start() was active. People didn't seem to like this because
it only allowed measuring a single thread. This approach allows
measuring multiple threads.

diff --git a/mercurial/statprof.py b/mercurial/statprof.py
--- a/mercurial/statprof.py
+++ b/mercurial/statprof.py
@@ -137,6 +137,9 @@ def clock():
     times = os.times()
     return times[0] + times[1]
 
+def currenttid():
+    frame = inspect.currentframe()
+    return [k for k, f in sys._current_frames().items() if f == frame][0]
 
 ###########################################################################
 ## Collection data structures
@@ -165,6 +168,28 @@ class ProfileState(object):
 
         self.samples = []
 
+    def filter_tid(self, tid):
+        """Obtain a copy of state with only samples from a specific thread.
+
+        A side effect is the samples from the requested thread are pruned
+        from the original samples list.
+        """
+        s = ProfileState()
+        s.accumulated_time = self.accumulated_time
+        s.last_start_time = self.last_start_time
+        s.sample_interval = self.sample_interval
+
+        newsamples = []
+        for t in self.samples:
+            if t[0] == tid:
+                s.samples.append(t)
+            else:
+                newsamples.append(t)
+
+        self.samples = newsamples
+
+        return s
+
     def accumulate_time(self, stop_time):
         self.accumulated_time += stop_time - self.last_start_time
 
@@ -172,7 +197,7 @@ class ProfileState(object):
         return self.accumulated_time / len(self.samples)
 
 state = ProfileState()
-
+activethreads = set()
 
 class CodeSite(object):
     cache = {}
@@ -257,20 +282,23 @@ def profile_signal_handler(signum, frame
         now = clock()
         state.accumulate_time(now)
 
-        state.samples.append(Sample.from_frame(frame, state.accumulated_time))
+        state.samples.append((None,
+                              Sample.from_frame(frame, state.accumulated_time)))
 
         signal.setitimer(signal.ITIMER_PROF,
             state.sample_interval, 0.0)
         state.last_start_time = now
 
 stopthread = threading.Event()
-def samplerthread(tid):
+def samplerthread():
     while not stopthread.is_set():
         now = clock()
         state.accumulate_time(now)
 
-        frame = sys._current_frames()[tid]
-        state.samples.append(Sample.from_frame(frame, state.accumulated_time))
+        for tid, frame in sys._current_frames().items():
+            if tid in activethreads:
+                sample = Sample.from_frame(frame, state.accumulated_time)
+                state.samples.append((tid, sample))
 
         state.last_start_time = now
         time.sleep(state.sample_interval)
@@ -286,6 +314,8 @@ def is_active():
 lastmechanism = None
 def start(mechanism='thread'):
     '''Install the profiling signal handler, and start profiling.'''
+    activethreads.add(currenttid())
+
     state.profile_level += 1
     if state.profile_level == 1:
         state.last_start_time = clock()
@@ -300,14 +330,18 @@ def start(mechanism='thread'):
             signal.setitimer(signal.ITIMER_PROF,
                 rpt or state.sample_interval, 0.0)
         elif mechanism == 'thread':
-            frame = inspect.currentframe()
-            tid = [k for k, f in sys._current_frames().items() if f == frame][0]
             state.thread = threading.Thread(target=samplerthread,
-                                 args=(tid,), name="samplerthread")
+                                 args=(), name="samplerthread")
             state.thread.start()
 
 def stop():
     '''Stop profiling, and uninstall the profiling signal handler.'''
+    tid = currenttid()
+    try:
+        activethreads.remove(tid)
+    except KeyError:
+        pass
+
     state.profile_level -= 1
     if state.profile_level == 0:
         if lastmechanism == 'signal':
@@ -324,12 +358,13 @@ def stop():
         if statprofpath:
             save_data(statprofpath)
 
-    return state
+    return state.filter_tid(tid)
 
 def save_data(path):
     with open(path, 'w+') as file:
         file.write(str(state.accumulated_time) + '\n')
         for sample in state.samples:
+            sample = sample[1]
             time = str(sample.time)
             stack = sample.stack
             sites = ['\1'.join([s.path, str(s.lineno), s.function])
@@ -351,7 +386,7 @@ def load_data(path):
             sites.append(CodeSite.get(siteparts[0], int(siteparts[1]),
                         siteparts[2]))
 
-        state.samples.append(Sample(sites, time))
+        state.samples.append((None, Sample(sites, time)))
 
 
 
@@ -408,7 +443,7 @@ class SiteStats(object):
         stats = {}
 
         for sample in samples:
-            for i, site in enumerate(sample.stack):
+            for i, site in enumerate(sample[1].stack):
                 sitestat = stats.get(site)
                 if not sitestat:
                     sitestat = SiteStats(site)
@@ -546,6 +581,7 @@ def display_about_method(data, fp, funct
     children = {}
 
     for sample in data.samples:
+        sample = sample[1]
         for i, site in enumerate(sample.stack):
             if site.function == function and (not filename
                 or site.filename() == filename):
@@ -626,8 +662,9 @@ def display_hotpath(data, fp, limit=0.05
                     child.add(stack[i:], time)
 
     root = HotNode(None)
-    lasttime = data.samples[0].time
+    lasttime = data.samples[0][1].time
     for sample in data.samples:
+        sample = sample[1]
         root.add(sample.stack[::-1], sample.time - lasttime)
         lasttime = sample.time
 
@@ -689,6 +726,7 @@ def write_to_flame(data, fp, scriptpath=
 
     lines = {}
     for sample in data.samples:
+        sample = sample[1]
         sites = [s.function for s in sample.stack]
         sites.reverse()
         line = ';'.join(sites)
@@ -712,6 +750,7 @@ def write_to_json(data, fp):
     samples = []
 
     for sample in data.samples:
+        sample = sample[1]
         stack = []
 
         for frame in sample.stack:


More information about the Mercurial-devel mailing list