Merge pull request #282 from campagnola/image-testing
TST: Add image testing
This commit is contained in:
commit
856c5eaadf
@ -1179,10 +1179,9 @@ def imageToArray(img, copy=False, transpose=True):
|
||||
# If this works on all platforms, then there is no need to use np.asarray..
|
||||
arr = np.frombuffer(ptr, np.ubyte, img.byteCount())
|
||||
|
||||
if fmt == img.Format_RGB32:
|
||||
arr = arr.reshape(img.height(), img.width(), 3)
|
||||
elif fmt == img.Format_ARGB32 or fmt == img.Format_ARGB32_Premultiplied:
|
||||
arr = arr.reshape(img.height(), img.width(), 4)
|
||||
if fmt == img.Format_RGB32:
|
||||
arr[...,3] = 255
|
||||
|
||||
if copy:
|
||||
arr = arr.copy()
|
||||
|
@ -126,10 +126,18 @@ class PlotCurveItem(GraphicsObject):
|
||||
|
||||
## Get min/max (or percentiles) of the requested data range
|
||||
if frac >= 1.0:
|
||||
# include complete data range
|
||||
# first try faster nanmin/max function, then cut out infs if needed.
|
||||
b = (np.nanmin(d), np.nanmax(d))
|
||||
if any(np.isinf(b)):
|
||||
mask = np.isfinite(d)
|
||||
d = d[mask]
|
||||
b = (d.min(), d.max())
|
||||
|
||||
elif frac <= 0.0:
|
||||
raise Exception("Value for parameter 'frac' must be > 0. (got %s)" % str(frac))
|
||||
else:
|
||||
# include a percentile of data range
|
||||
mask = np.isfinite(d)
|
||||
d = d[mask]
|
||||
b = np.percentile(d, [50 * (1 - frac), 50 * (1 + frac)])
|
||||
|
34
pyqtgraph/graphicsItems/tests/test_PlotCurveItem.py
Normal file
34
pyqtgraph/graphicsItems/tests/test_PlotCurveItem.py
Normal file
@ -0,0 +1,34 @@
|
||||
import numpy as np
|
||||
import pyqtgraph as pg
|
||||
from pyqtgraph.tests import assertImageApproved
|
||||
|
||||
|
||||
def test_PlotCurveItem():
|
||||
p = pg.GraphicsWindow()
|
||||
p.ci.layout.setContentsMargins(4, 4, 4, 4) # default margins vary by platform
|
||||
v = p.addViewBox()
|
||||
p.resize(200, 150)
|
||||
data = np.array([1,4,2,3,np.inf,5,7,6,-np.inf,8,10,9,np.nan,-1,-2,0])
|
||||
c = pg.PlotCurveItem(data)
|
||||
v.addItem(c)
|
||||
v.autoRange()
|
||||
|
||||
# Check auto-range works. Some platform differences may be expected..
|
||||
checkRange = np.array([[-1.1457564053237301, 16.145756405323731], [-3.076811473165955, 11.076811473165955]])
|
||||
assert np.allclose(v.viewRange(), checkRange)
|
||||
|
||||
assertImageApproved(p, 'plotcurveitem/connectall', "Plot curve with all points connected.")
|
||||
|
||||
c.setData(data, connect='pairs')
|
||||
assertImageApproved(p, 'plotcurveitem/connectpairs', "Plot curve with pairs connected.")
|
||||
|
||||
c.setData(data, connect='finite')
|
||||
assertImageApproved(p, 'plotcurveitem/connectfinite', "Plot curve with finite points connected.")
|
||||
|
||||
c.setData(data, connect=np.array([1,1,1,0,1,1,0,0,1,0,0,0,1,1,0,0]))
|
||||
assertImageApproved(p, 'plotcurveitem/connectarray', "Plot curve with connection array.")
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_PlotCurveItem()
|
@ -1 +1,2 @@
|
||||
from .image_testing import assertImageApproved
|
||||
from .ui_testing import mousePress, mouseMove, mouseRelease, mouseDrag, mouseClick
|
||||
|
533
pyqtgraph/tests/image_testing.py
Normal file
533
pyqtgraph/tests/image_testing.py
Normal file
@ -0,0 +1,533 @@
|
||||
# Image-based testing borrowed from vispy
|
||||
|
||||
"""
|
||||
Procedure for unit-testing with images:
|
||||
|
||||
1. Run unit tests at least once; this initializes a git clone of
|
||||
pyqtgraph/test-data in ~/.pyqtgraph.
|
||||
|
||||
2. Run individual test scripts with the PYQTGRAPH_AUDIT environment variable set:
|
||||
|
||||
$ PYQTGRAPH_AUDIT=1 python pyqtgraph/graphicsItems/tests/test_PlotCurveItem.py
|
||||
|
||||
Any failing tests will
|
||||
display the test results, standard image, and the differences between the
|
||||
two. If the test result is bad, then press (f)ail. If the test result is
|
||||
good, then press (p)ass and the new image will be saved to the test-data
|
||||
directory.
|
||||
|
||||
3. After adding or changing test images, create a new commit:
|
||||
|
||||
$ cd ~/.pyqtgraph/test-data
|
||||
$ git add ...
|
||||
$ git commit -a
|
||||
|
||||
4. Look up the most recent tag name from the `testDataTag` global variable
|
||||
below. Increment the tag name by 1 and create a new tag in the test-data
|
||||
repository:
|
||||
|
||||
$ git tag test-data-NNN
|
||||
$ git push --tags origin master
|
||||
|
||||
This tag is used to ensure that each pyqtgraph commit is linked to a specific
|
||||
commit in the test-data repository. This makes it possible to push new
|
||||
commits to the test-data repository without interfering with existing
|
||||
tests, and also allows unit tests to continue working on older pyqtgraph
|
||||
versions.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
# This is the name of a tag in the test-data repository that this version of
|
||||
# pyqtgraph should be tested against. When adding or changing test images,
|
||||
# create and push a new tag and update this variable.
|
||||
testDataTag = 'test-data-3'
|
||||
|
||||
|
||||
import time
|
||||
import os
|
||||
import sys
|
||||
import inspect
|
||||
import base64
|
||||
import subprocess as sp
|
||||
import numpy as np
|
||||
|
||||
if sys.version[0] >= '3':
|
||||
import http.client as httplib
|
||||
import urllib.parse as urllib
|
||||
else:
|
||||
import httplib
|
||||
import urllib
|
||||
from ..Qt import QtGui, QtCore
|
||||
from .. import functions as fn
|
||||
from .. import GraphicsLayoutWidget
|
||||
from .. import ImageItem, TextItem
|
||||
|
||||
|
||||
tester = None
|
||||
|
||||
|
||||
def getTester():
|
||||
global tester
|
||||
if tester is None:
|
||||
tester = ImageTester()
|
||||
return tester
|
||||
|
||||
|
||||
def assertImageApproved(image, standardFile, message=None, **kwargs):
|
||||
"""Check that an image test result matches a pre-approved standard.
|
||||
|
||||
If the result does not match, then the user can optionally invoke a GUI
|
||||
to compare the images and decide whether to fail the test or save the new
|
||||
image as the standard.
|
||||
|
||||
This function will automatically clone the test-data repository into
|
||||
~/.pyqtgraph/test-data. However, it is up to the user to ensure this repository
|
||||
is kept up to date and to commit/push new images after they are saved.
|
||||
|
||||
Run the test with the environment variable PYQTGRAPH_AUDIT=1 to bring up
|
||||
the auditing GUI.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
image : (h, w, 4) ndarray
|
||||
standardFile : str
|
||||
The name of the approved test image to check against. This file name
|
||||
is relative to the root of the pyqtgraph test-data repository and will
|
||||
be automatically fetched.
|
||||
message : str
|
||||
A string description of the image. It is recommended to describe
|
||||
specific features that an auditor should look for when deciding whether
|
||||
to fail a test.
|
||||
|
||||
Extra keyword arguments are used to set the thresholds for automatic image
|
||||
comparison (see ``assertImageMatch()``).
|
||||
"""
|
||||
if isinstance(image, QtGui.QWidget):
|
||||
w = image
|
||||
image = np.zeros((w.height(), w.width(), 4), dtype=np.ubyte)
|
||||
qimg = fn.makeQImage(image, alpha=True, copy=False, transpose=False)
|
||||
painter = QtGui.QPainter(qimg)
|
||||
w.render(painter)
|
||||
painter.end()
|
||||
|
||||
if message is None:
|
||||
code = inspect.currentframe().f_back.f_code
|
||||
message = "%s::%s" % (code.co_filename, code.co_name)
|
||||
|
||||
# Make sure we have a test data repo available, possibly invoking git
|
||||
dataPath = getTestDataRepo()
|
||||
|
||||
# Read the standard image if it exists
|
||||
stdFileName = os.path.join(dataPath, standardFile + '.png')
|
||||
if not os.path.isfile(stdFileName):
|
||||
stdImage = None
|
||||
else:
|
||||
pxm = QtGui.QPixmap()
|
||||
pxm.load(stdFileName)
|
||||
stdImage = fn.imageToArray(pxm.toImage(), copy=True, transpose=False)
|
||||
|
||||
# If the test image does not match, then we go to audit if requested.
|
||||
try:
|
||||
if image.shape[2] != stdImage.shape[2]:
|
||||
raise Exception("Test result has different channel count than standard image"
|
||||
"(%d vs %d)" % (image.shape[2], stdImage.shape[2]))
|
||||
if image.shape != stdImage.shape:
|
||||
# Allow im1 to be an integer multiple larger than im2 to account
|
||||
# for high-resolution displays
|
||||
ims1 = np.array(image.shape).astype(float)
|
||||
ims2 = np.array(stdImage.shape).astype(float)
|
||||
sr = ims1 / ims2 if ims1[0] > ims2[0] else ims2 / ims1
|
||||
if (sr[0] != sr[1] or not np.allclose(sr, np.round(sr)) or
|
||||
sr[0] < 1):
|
||||
raise TypeError("Test result shape %s is not an integer factor"
|
||||
" different than standard image shape %s." %
|
||||
(ims1, ims2))
|
||||
sr = np.round(sr).astype(int)
|
||||
image = downsample(image, sr[0], axis=(0, 1)).astype(image.dtype)
|
||||
|
||||
assertImageMatch(image, stdImage, **kwargs)
|
||||
except Exception:
|
||||
if stdFileName in gitStatus(dataPath):
|
||||
print("\n\nWARNING: unit test failed against modified standard "
|
||||
"image %s.\nTo revert this file, run `cd %s; git checkout "
|
||||
"%s`\n" % (stdFileName, dataPath, standardFile))
|
||||
if os.getenv('PYQTGRAPH_AUDIT') == '1':
|
||||
sys.excepthook(*sys.exc_info())
|
||||
getTester().test(image, stdImage, message)
|
||||
stdPath = os.path.dirname(stdFileName)
|
||||
print('Saving new standard image to "%s"' % stdFileName)
|
||||
if not os.path.isdir(stdPath):
|
||||
os.makedirs(stdPath)
|
||||
img = fn.makeQImage(image, alpha=True, copy=False, transpose=False)
|
||||
img.save(stdFileName)
|
||||
else:
|
||||
if stdImage is None:
|
||||
raise Exception("Test standard %s does not exist. Set "
|
||||
"PYQTGRAPH_AUDIT=1 to add this image." % stdFileName)
|
||||
else:
|
||||
if os.getenv('TRAVIS') is not None:
|
||||
saveFailedTest(image, stdImage, standardFile)
|
||||
raise
|
||||
|
||||
|
||||
def assertImageMatch(im1, im2, minCorr=None, pxThreshold=50.,
|
||||
pxCount=0, maxPxDiff=None, avgPxDiff=None,
|
||||
imgDiff=None):
|
||||
"""Check that two images match.
|
||||
|
||||
Images that differ in shape or dtype will fail unconditionally.
|
||||
Further tests for similarity depend on the arguments supplied.
|
||||
|
||||
By default, images may have no pixels that gave a value difference greater
|
||||
than 50.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
im1 : (h, w, 4) ndarray
|
||||
Test output image
|
||||
im2 : (h, w, 4) ndarray
|
||||
Test standard image
|
||||
minCorr : float or None
|
||||
Minimum allowed correlation coefficient between corresponding image
|
||||
values (see numpy.corrcoef)
|
||||
pxThreshold : float
|
||||
Minimum value difference at which two pixels are considered different
|
||||
pxCount : int or None
|
||||
Maximum number of pixels that may differ
|
||||
maxPxDiff : float or None
|
||||
Maximum allowed difference between pixels
|
||||
avgPxDiff : float or None
|
||||
Average allowed difference between pixels
|
||||
imgDiff : float or None
|
||||
Maximum allowed summed difference between images
|
||||
|
||||
"""
|
||||
assert im1.ndim == 3
|
||||
assert im1.shape[2] == 4
|
||||
assert im1.dtype == im2.dtype
|
||||
|
||||
diff = im1.astype(float) - im2.astype(float)
|
||||
if imgDiff is not None:
|
||||
assert np.abs(diff).sum() <= imgDiff
|
||||
|
||||
pxdiff = diff.max(axis=2) # largest value difference per pixel
|
||||
mask = np.abs(pxdiff) >= pxThreshold
|
||||
if pxCount is not None:
|
||||
assert mask.sum() <= pxCount
|
||||
|
||||
maskedDiff = diff[mask]
|
||||
if maxPxDiff is not None and maskedDiff.size > 0:
|
||||
assert maskedDiff.max() <= maxPxDiff
|
||||
if avgPxDiff is not None and maskedDiff.size > 0:
|
||||
assert maskedDiff.mean() <= avgPxDiff
|
||||
|
||||
if minCorr is not None:
|
||||
with np.errstate(invalid='ignore'):
|
||||
corr = np.corrcoef(im1.ravel(), im2.ravel())[0, 1]
|
||||
assert corr >= minCorr
|
||||
|
||||
|
||||
def saveFailedTest(data, expect, filename):
|
||||
"""Upload failed test images to web server to allow CI test debugging.
|
||||
"""
|
||||
commit, error = runSubprocess(['git', 'rev-parse', 'HEAD'])
|
||||
name = filename.split('/')
|
||||
name.insert(-1, commit.strip())
|
||||
filename = '/'.join(name)
|
||||
host = 'data.pyqtgraph.org'
|
||||
|
||||
# concatenate data, expect, and diff into a single image
|
||||
ds = data.shape
|
||||
es = expect.shape
|
||||
|
||||
shape = (max(ds[0], es[0]) + 4, ds[1] + es[1] + 8 + max(ds[1], es[1]), 4)
|
||||
img = np.empty(shape, dtype=np.ubyte)
|
||||
img[..., :3] = 100
|
||||
img[..., 3] = 255
|
||||
|
||||
img[2:2+ds[0], 2:2+ds[1], :ds[2]] = data
|
||||
img[2:2+es[0], ds[1]+4:ds[1]+4+es[1], :es[2]] = expect
|
||||
|
||||
diff = makeDiffImage(data, expect)
|
||||
img[2:2+diff.shape[0], -diff.shape[1]-2:-2] = diff
|
||||
|
||||
png = makePng(img)
|
||||
|
||||
conn = httplib.HTTPConnection(host)
|
||||
req = urllib.urlencode({'name': filename,
|
||||
'data': base64.b64encode(png)})
|
||||
conn.request('POST', '/upload.py', req)
|
||||
response = conn.getresponse().read()
|
||||
conn.close()
|
||||
print("\nImage comparison failed. Test result: %s %s Expected result: "
|
||||
"%s %s" % (data.shape, data.dtype, expect.shape, expect.dtype))
|
||||
print("Uploaded to: \nhttp://%s/data/%s" % (host, filename))
|
||||
if not response.startswith(b'OK'):
|
||||
print("WARNING: Error uploading data to %s" % host)
|
||||
print(response)
|
||||
|
||||
|
||||
def makePng(img):
|
||||
"""Given an array like (H, W, 4), return a PNG-encoded byte string.
|
||||
"""
|
||||
io = QtCore.QBuffer()
|
||||
qim = fn.makeQImage(img, alpha=False)
|
||||
qim.save(io, format='png')
|
||||
png = io.data().data().encode()
|
||||
return png
|
||||
|
||||
|
||||
def makeDiffImage(im1, im2):
|
||||
"""Return image array showing the differences between im1 and im2.
|
||||
|
||||
Handles images of different shape. Alpha channels are not compared.
|
||||
"""
|
||||
ds = im1.shape
|
||||
es = im2.shape
|
||||
|
||||
diff = np.empty((max(ds[0], es[0]), max(ds[1], es[1]), 4), dtype=int)
|
||||
diff[..., :3] = 128
|
||||
diff[..., 3] = 255
|
||||
diff[:ds[0], :ds[1], :min(ds[2], 3)] += im1[..., :3]
|
||||
diff[:es[0], :es[1], :min(es[2], 3)] -= im2[..., :3]
|
||||
diff = np.clip(diff, 0, 255).astype(np.ubyte)
|
||||
return diff
|
||||
|
||||
|
||||
class ImageTester(QtGui.QWidget):
|
||||
"""Graphical interface for auditing image comparison tests.
|
||||
"""
|
||||
def __init__(self):
|
||||
self.lastKey = None
|
||||
|
||||
QtGui.QWidget.__init__(self)
|
||||
self.resize(1200, 800)
|
||||
self.showFullScreen()
|
||||
|
||||
self.layout = QtGui.QGridLayout()
|
||||
self.setLayout(self.layout)
|
||||
|
||||
self.view = GraphicsLayoutWidget()
|
||||
self.layout.addWidget(self.view, 0, 0, 1, 2)
|
||||
|
||||
self.label = QtGui.QLabel()
|
||||
self.layout.addWidget(self.label, 1, 0, 1, 2)
|
||||
self.label.setWordWrap(True)
|
||||
font = QtGui.QFont("monospace", 14, QtGui.QFont.Bold)
|
||||
self.label.setFont(font)
|
||||
|
||||
self.passBtn = QtGui.QPushButton('Pass')
|
||||
self.failBtn = QtGui.QPushButton('Fail')
|
||||
self.layout.addWidget(self.passBtn, 2, 0)
|
||||
self.layout.addWidget(self.failBtn, 2, 1)
|
||||
|
||||
self.views = (self.view.addViewBox(row=0, col=0),
|
||||
self.view.addViewBox(row=0, col=1),
|
||||
self.view.addViewBox(row=0, col=2))
|
||||
labelText = ['test output', 'standard', 'diff']
|
||||
for i, v in enumerate(self.views):
|
||||
v.setAspectLocked(1)
|
||||
v.invertY()
|
||||
v.image = ImageItem()
|
||||
v.image.setAutoDownsample(True)
|
||||
v.addItem(v.image)
|
||||
v.label = TextItem(labelText[i])
|
||||
v.setBackgroundColor(0.5)
|
||||
|
||||
self.views[1].setXLink(self.views[0])
|
||||
self.views[1].setYLink(self.views[0])
|
||||
self.views[2].setXLink(self.views[0])
|
||||
self.views[2].setYLink(self.views[0])
|
||||
|
||||
def test(self, im1, im2, message):
|
||||
"""Ask the user to decide whether an image test passes or fails.
|
||||
|
||||
This method displays the test image, reference image, and the difference
|
||||
between the two. It then blocks until the user selects the test output
|
||||
by clicking a pass/fail button or typing p/f. If the user fails the test,
|
||||
then an exception is raised.
|
||||
"""
|
||||
self.show()
|
||||
if im2 is None:
|
||||
message += '\nImage1: %s %s Image2: [no standard]' % (im1.shape, im1.dtype)
|
||||
im2 = np.zeros((1, 1, 3), dtype=np.ubyte)
|
||||
else:
|
||||
message += '\nImage1: %s %s Image2: %s %s' % (im1.shape, im1.dtype, im2.shape, im2.dtype)
|
||||
self.label.setText(message)
|
||||
|
||||
self.views[0].image.setImage(im1.transpose(1, 0, 2))
|
||||
self.views[1].image.setImage(im2.transpose(1, 0, 2))
|
||||
diff = makeDiffImage(im1, im2).transpose(1, 0, 2)
|
||||
|
||||
self.views[2].image.setImage(diff)
|
||||
self.views[0].autoRange()
|
||||
|
||||
while True:
|
||||
QtGui.QApplication.processEvents()
|
||||
lastKey = self.lastKey
|
||||
|
||||
self.lastKey = None
|
||||
if lastKey in ('f', 'esc') or not self.isVisible():
|
||||
raise Exception("User rejected test result.")
|
||||
elif lastKey == 'p':
|
||||
break
|
||||
time.sleep(0.03)
|
||||
|
||||
for v in self.views:
|
||||
v.image.setImage(np.zeros((1, 1, 3), dtype=np.ubyte))
|
||||
|
||||
def keyPressEvent(self, event):
|
||||
if event.key() == QtCore.Qt.Key_Escape:
|
||||
self.lastKey = 'esc'
|
||||
else:
|
||||
self.lastKey = str(event.text()).lower()
|
||||
|
||||
|
||||
def getTestDataRepo():
|
||||
"""Return the path to a git repository with the required commit checked
|
||||
out.
|
||||
|
||||
If the repository does not exist, then it is cloned from
|
||||
https://github.com/pyqtgraph/test-data. If the repository already exists
|
||||
then the required commit is checked out.
|
||||
"""
|
||||
global testDataTag
|
||||
|
||||
dataPath = os.path.join(os.path.expanduser('~'), '.pyqtgraph', 'test-data')
|
||||
gitPath = 'https://github.com/pyqtgraph/test-data'
|
||||
gitbase = gitCmdBase(dataPath)
|
||||
|
||||
if os.path.isdir(dataPath):
|
||||
# Already have a test-data repository to work with.
|
||||
|
||||
# Get the commit ID of testDataTag. Do a fetch if necessary.
|
||||
try:
|
||||
tagCommit = gitCommitId(dataPath, testDataTag)
|
||||
except NameError:
|
||||
cmd = gitbase + ['fetch', '--tags', 'origin']
|
||||
print(' '.join(cmd))
|
||||
sp.check_call(cmd)
|
||||
try:
|
||||
tagCommit = gitCommitId(dataPath, testDataTag)
|
||||
except NameError:
|
||||
raise Exception("Could not find tag '%s' in test-data repo at"
|
||||
" %s" % (testDataTag, dataPath))
|
||||
except Exception:
|
||||
if not os.path.exists(os.path.join(dataPath, '.git')):
|
||||
raise Exception("Directory '%s' does not appear to be a git "
|
||||
"repository. Please remove this directory." %
|
||||
dataPath)
|
||||
else:
|
||||
raise
|
||||
|
||||
# If HEAD is not the correct commit, then do a checkout
|
||||
if gitCommitId(dataPath, 'HEAD') != tagCommit:
|
||||
print("Checking out test-data tag '%s'" % testDataTag)
|
||||
sp.check_call(gitbase + ['checkout', testDataTag])
|
||||
|
||||
else:
|
||||
print("Attempting to create git clone of test data repo in %s.." %
|
||||
dataPath)
|
||||
|
||||
parentPath = os.path.split(dataPath)[0]
|
||||
if not os.path.isdir(parentPath):
|
||||
os.makedirs(parentPath)
|
||||
|
||||
if os.getenv('TRAVIS') is not None:
|
||||
# Create a shallow clone of the test-data repository (to avoid
|
||||
# downloading more data than is necessary)
|
||||
os.makedirs(dataPath)
|
||||
cmds = [
|
||||
gitbase + ['init'],
|
||||
gitbase + ['remote', 'add', 'origin', gitPath],
|
||||
gitbase + ['fetch', '--tags', 'origin', testDataTag,
|
||||
'--depth=1'],
|
||||
gitbase + ['checkout', '-b', 'master', 'FETCH_HEAD'],
|
||||
]
|
||||
else:
|
||||
# Create a full clone
|
||||
cmds = [['git', 'clone', gitPath, dataPath]]
|
||||
|
||||
for cmd in cmds:
|
||||
print(' '.join(cmd))
|
||||
rval = sp.check_call(cmd)
|
||||
if rval == 0:
|
||||
continue
|
||||
raise RuntimeError("Test data path '%s' does not exist and could "
|
||||
"not be created with git. Please create a git "
|
||||
"clone of %s at this path." %
|
||||
(dataPath, gitPath))
|
||||
|
||||
return dataPath
|
||||
|
||||
|
||||
def gitCmdBase(path):
|
||||
return ['git', '--git-dir=%s/.git' % path, '--work-tree=%s' % path]
|
||||
|
||||
|
||||
def gitStatus(path):
|
||||
"""Return a string listing all changes to the working tree in a git
|
||||
repository.
|
||||
"""
|
||||
cmd = gitCmdBase(path) + ['status', '--porcelain']
|
||||
return runSubprocess(cmd, stderr=None, universal_newlines=True)
|
||||
|
||||
|
||||
def gitCommitId(path, ref):
|
||||
"""Return the commit id of *ref* in the git repository at *path*.
|
||||
"""
|
||||
cmd = gitCmdBase(path) + ['show', ref]
|
||||
try:
|
||||
output = runSubprocess(cmd, stderr=None, universal_newlines=True)
|
||||
except sp.CalledProcessError:
|
||||
print(cmd)
|
||||
raise NameError("Unknown git reference '%s'" % ref)
|
||||
commit = output.split('\n')[0]
|
||||
assert commit[:7] == 'commit '
|
||||
return commit[7:]
|
||||
|
||||
|
||||
def runSubprocess(command, return_code=False, **kwargs):
|
||||
"""Run command using subprocess.Popen
|
||||
|
||||
Similar to subprocess.check_output(), which is not available in 2.6.
|
||||
|
||||
Run command and wait for command to complete. If the return code was zero
|
||||
then return, otherwise raise CalledProcessError.
|
||||
By default, this will also add stdout= and stderr=subproces.PIPE
|
||||
to the call to Popen to suppress printing to the terminal.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
command : list of str
|
||||
Command to run as subprocess (see subprocess.Popen documentation).
|
||||
**kwargs : dict
|
||||
Additional kwargs to pass to ``subprocess.Popen``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
stdout : str
|
||||
Stdout returned by the process.
|
||||
"""
|
||||
# code adapted with permission from mne-python
|
||||
use_kwargs = dict(stderr=None, stdout=sp.PIPE)
|
||||
use_kwargs.update(kwargs)
|
||||
|
||||
p = sp.Popen(command, **use_kwargs)
|
||||
output = p.communicate()[0]
|
||||
|
||||
# communicate() may return bytes, str, or None depending on the kwargs
|
||||
# passed to Popen(). Convert all to unicode str:
|
||||
output = '' if output is None else output
|
||||
output = output.decode('utf-8') if isinstance(output, bytes) else output
|
||||
|
||||
if p.returncode != 0:
|
||||
print(output)
|
||||
err_fun = sp.CalledProcessError.__init__
|
||||
if 'output' in inspect.getargspec(err_fun).args:
|
||||
raise sp.CalledProcessError(p.returncode, command, output)
|
||||
else:
|
||||
raise sp.CalledProcessError(p.returncode, command)
|
||||
|
||||
return output
|
Loading…
Reference in New Issue
Block a user