From b09182d19af815976917dba95db19bd0010f2efa Mon Sep 17 00:00:00 2001 From: Luke Campagnola <> Date: Fri, 26 Oct 2012 21:47:45 -0400 Subject: [PATCH] GLScatterPlotItem: use shader programs to allow specifying spot size by array Reorganized shader programs Infrastructure updates for OpenGL system --- SRTTransform3D.py | 12 +-- Transform3D.py | 35 ++++++++ Vector.py | 123 ++++++++++++++------------- __init__.py | 1 + examples/GLScatterPlotItem.py | 93 +++++++++++++++------ opengl/GLGraphicsItem.py | 44 ++++++++-- opengl/GLViewWidget.py | 12 ++- opengl/items/GLMeshItem.py | 10 ++- opengl/items/GLScatterPlotItem.py | 130 +++++++++++++++++++---------- opengl/shaders.py | 134 ++++++++++++++++++++++-------- 10 files changed, 411 insertions(+), 183 deletions(-) create mode 100644 Transform3D.py diff --git a/SRTTransform3D.py b/SRTTransform3D.py index 94c3df77..89b8ab13 100644 --- a/SRTTransform3D.py +++ b/SRTTransform3D.py @@ -6,12 +6,12 @@ import pyqtgraph as pg import numpy as np import scipy.linalg -class SRTTransform3D(QtGui.QMatrix4x4): +class SRTTransform3D(pg.Transform3D): """4x4 Transform matrix that can always be represented as a combination of 3 matrices: scale * rotate * translate This transform has no shear; angles are always preserved. """ def __init__(self, init=None): - QtGui.QMatrix4x4.__init__(self) + pg.Transform3D.__init__(self) self.reset() if init is None: return @@ -190,11 +190,11 @@ class SRTTransform3D(QtGui.QMatrix4x4): self.update() def update(self): - QtGui.QMatrix4x4.setToIdentity(self) + pg.Transform3D.setToIdentity(self) ## modifications to the transform are multiplied on the right, so we need to reverse order here. - QtGui.QMatrix4x4.translate(self, *self._state['pos']) - QtGui.QMatrix4x4.rotate(self, self._state['angle'], *self._state['axis']) - QtGui.QMatrix4x4.scale(self, *self._state['scale']) + pg.Transform3D.translate(self, *self._state['pos']) + pg.Transform3D.rotate(self, self._state['angle'], *self._state['axis']) + pg.Transform3D.scale(self, *self._state['scale']) def __repr__(self): return str(self.saveState()) diff --git a/Transform3D.py b/Transform3D.py new file mode 100644 index 00000000..aa948e28 --- /dev/null +++ b/Transform3D.py @@ -0,0 +1,35 @@ +# -*- coding: utf-8 -*- +from .Qt import QtCore, QtGui +import pyqtgraph as pg +import numpy as np + +class Transform3D(QtGui.QMatrix4x4): + """ + Extension of QMatrix4x4 with some helpful methods added. + """ + def __init__(self, *args): + QtGui.QMatrix4x4.__init__(self, *args) + + def matrix(self, nd=3): + if nd == 3: + return np.array(self.copyDataTo()).reshape(4,4) + elif nd == 2: + m = np.array(self.copyDataTo()).reshape(4,4) + m[2] = m[3] + m[:,2] = m[:,3] + return m[:3,:3] + else: + raise Exception("Argument 'nd' must be 2 or 3") + + def map(self, obj): + """ + Extends QMatrix4x4.map() to allow mapping (3, ...) arrays of coordinates + """ + if isinstance(obj, np.ndarray) and obj.ndim >= 2 and obj.shape[0] in (2,3): + return pg.transformCoordinates(self, obj) + else: + return QtGui.QMatrix4x4.map(self, obj) + + def inverted(self): + inv, b = QtGui.QMatrix4x4.inverted(self) + return Transform3D(inv), b \ No newline at end of file diff --git a/Vector.py b/Vector.py index 79da3162..e9c109d8 100644 --- a/Vector.py +++ b/Vector.py @@ -1,59 +1,64 @@ -# -*- coding: utf-8 -*- -""" -Vector.py - Extension of QVector3D which adds a few missing methods. -Copyright 2010 Luke Campagnola -Distributed under MIT/X11 license. See license.txt for more infomation. -""" - -from .Qt import QtGui, QtCore -import numpy as np - -class Vector(QtGui.QVector3D): - """Extension of QVector3D which adds a few helpful methods.""" - - def __init__(self, *args): - if len(args) == 1: - if isinstance(args[0], QtCore.QSizeF): - QtGui.QVector3D.__init__(self, float(args[0].width()), float(args[0].height()), 0) - return - elif isinstance(args[0], QtCore.QPoint) or isinstance(args[0], QtCore.QPointF): - QtGui.QVector3D.__init__(self, float(args[0].x()), float(args[0].y()), 0) - elif hasattr(args[0], '__getitem__'): - vals = list(args[0]) - if len(vals) == 2: - vals.append(0) - if len(vals) != 3: - raise Exception('Cannot init Vector with sequence of length %d' % len(args[0])) - QtGui.QVector3D.__init__(self, *vals) - return - elif len(args) == 2: - QtGui.QVector3D.__init__(self, args[0], args[1], 0) - return - QtGui.QVector3D.__init__(self, *args) - - def __len__(self): - return 3 - - #def __reduce__(self): - #return (Point, (self.x(), self.y())) - - def __getitem__(self, i): - if i == 0: - return self.x() - elif i == 1: - return self.y() - elif i == 2: - return self.z() - else: - raise IndexError("Point has no index %s" % str(i)) - - def __setitem__(self, i, x): - if i == 0: - return self.setX(x) - elif i == 1: - return self.setY(x) - elif i == 2: - return self.setZ(x) - else: - raise IndexError("Point has no index %s" % str(i)) - +# -*- coding: utf-8 -*- +""" +Vector.py - Extension of QVector3D which adds a few missing methods. +Copyright 2010 Luke Campagnola +Distributed under MIT/X11 license. See license.txt for more infomation. +""" + +from .Qt import QtGui, QtCore +import numpy as np + +class Vector(QtGui.QVector3D): + """Extension of QVector3D which adds a few helpful methods.""" + + def __init__(self, *args): + if len(args) == 1: + if isinstance(args[0], QtCore.QSizeF): + QtGui.QVector3D.__init__(self, float(args[0].width()), float(args[0].height()), 0) + return + elif isinstance(args[0], QtCore.QPoint) or isinstance(args[0], QtCore.QPointF): + QtGui.QVector3D.__init__(self, float(args[0].x()), float(args[0].y()), 0) + elif hasattr(args[0], '__getitem__'): + vals = list(args[0]) + if len(vals) == 2: + vals.append(0) + if len(vals) != 3: + raise Exception('Cannot init Vector with sequence of length %d' % len(args[0])) + QtGui.QVector3D.__init__(self, *vals) + return + elif len(args) == 2: + QtGui.QVector3D.__init__(self, args[0], args[1], 0) + return + QtGui.QVector3D.__init__(self, *args) + + def __len__(self): + return 3 + + #def __reduce__(self): + #return (Point, (self.x(), self.y())) + + def __getitem__(self, i): + if i == 0: + return self.x() + elif i == 1: + return self.y() + elif i == 2: + return self.z() + else: + raise IndexError("Point has no index %s" % str(i)) + + def __setitem__(self, i, x): + if i == 0: + return self.setX(x) + elif i == 1: + return self.setY(x) + elif i == 2: + return self.setZ(x) + else: + raise IndexError("Point has no index %s" % str(i)) + + def __iter__(self): + yield(self.x()) + yield(self.y()) + yield(self.z()) + \ No newline at end of file diff --git a/__init__.py b/__init__.py index bd7c2e76..dbb54ca9 100644 --- a/__init__.py +++ b/__init__.py @@ -165,6 +165,7 @@ from .WidgetGroup import * from .Point import Point from .Vector import Vector from .SRTTransform import SRTTransform +from .Transform3D import Transform3D from .SRTTransform3D import SRTTransform3D from .functions import * from .graphicsWindows import * diff --git a/examples/GLScatterPlotItem.py b/examples/GLScatterPlotItem.py index 16033520..e73eacd9 100644 --- a/examples/GLScatterPlotItem.py +++ b/examples/GLScatterPlotItem.py @@ -15,48 +15,91 @@ w.show() g = gl.GLGridItem() w.addItem(g) -#pos = np.empty((53, 3)) -#size = np.empty((53)) -#color = np.empty((53, 4)) -#pos[0] = (1,0,0); size[0] = 0.5; color[0] = (1.0, 0.0, 0.0, 0.5) -#pos[1] = (0,1,0); size[1] = 0.2; color[1] = (0.0, 0.0, 1.0, 0.5) -#pos[2] = (0,0,1); size[2] = 2./3.; color[2] = (0.0, 1.0, 0.0, 0.5) -#z = 0.5 -#d = 6.0 -#for i in range(3,53): - #pos[i] = (0,0,z) - #size[i] = 2./d - #color[i] = (0.0, 1.0, 0.0, 0.5) - #z *= 0.5 - #d *= 2.0 +## +## First example is a set of points with pxMode=False +## These demonstrate the ability to have points with real size down to a very small scale +## +pos = np.empty((53, 3)) +size = np.empty((53)) +color = np.empty((53, 4)) +pos[0] = (1,0,0); size[0] = 0.5; color[0] = (1.0, 0.0, 0.0, 0.5) +pos[1] = (0,1,0); size[1] = 0.2; color[1] = (0.0, 0.0, 1.0, 0.5) +pos[2] = (0,0,1); size[2] = 2./3.; color[2] = (0.0, 1.0, 0.0, 0.5) + +z = 0.5 +d = 6.0 +for i in range(3,53): + pos[i] = (0,0,z) + size[i] = 2./d + color[i] = (0.0, 1.0, 0.0, 0.5) + z *= 0.5 + d *= 2.0 -#sp = gl.GLScatterPlotItem(pos=pos, sizes=size, colors=color, pxMode=False) +sp1 = gl.GLScatterPlotItem(pos=pos, size=size, color=color, pxMode=False) +sp1.translate(5,5,0) +w.addItem(sp1) -pos = (np.random.random(size=(100000,3)) * 10) - 5 +## +## Second example shows a volume of points with rapidly updating color +## and pxMode=True +## + +pos = np.random.random(size=(100000,3)) +pos *= [10,-10,10] +pos[0] = (0,0,0) color = np.ones((pos.shape[0], 4)) -d = (pos**2).sum(axis=1)**0.5 -color[:,3] = np.clip(-np.cos(d*2) * 0.2, 0, 1) -sp = gl.GLScatterPlotItem(pos=pos, color=color, size=5) +d2 = (pos**2).sum(axis=1)**0.5 +size = np.random.random(size=pos.shape[0])*10 +sp2 = gl.GLScatterPlotItem(pos=pos, color=(1,1,1,1), size=size) phase = 0. +w.addItem(sp2) + + +## +## Third example shows a grid of points with rapidly updating position +## and pxMode = False +## + +pos3 = np.zeros((100,100,3)) +pos3[:,:,:2] = np.mgrid[:100, :100].transpose(1,2,0) * [-0.1,0.1] +pos3 = pos3.reshape(10000,3) +d3 = (pos3**2).sum(axis=1)**0.5 + +sp3 = gl.GLScatterPlotItem(pos=pos3, color=(1,1,1,.3), size=0.1, pxMode=False) + +w.addItem(sp3) + + def update(): - global phase, color, sp, d - s = -np.cos(d*2+phase) - color[:,3] = np.clip(s * 0.2, 0, 1) + ## update volume colors + global phase, sp2, d2 + s = -np.cos(d2*2+phase) + color = np.empty((len(d2),4), dtype=np.float32) + color[:,3] = np.clip(s * 0.1, 0, 1) color[:,0] = np.clip(s * 3.0, 0, 1) color[:,1] = np.clip(s * 1.0, 0, 1) color[:,2] = np.clip(s ** 3, 0, 1) - - sp.setData(color=color) + sp2.setData(color=color) phase -= 0.1 + ## update surface positions and colors + global sp3, d3, pos3 + z = -np.cos(d3*2+phase) + pos3[:,2] = z + color = np.empty((len(d3),4), dtype=np.float32) + color[:,3] = 0.3 + color[:,0] = np.clip(z * 3.0, 0, 1) + color[:,1] = np.clip(z * 1.0, 0, 1) + color[:,2] = np.clip(z ** 3, 0, 1) + sp3.setData(pos=pos3, color=color) + t = QtCore.QTimer() t.timeout.connect(update) t.start(50) -w.addItem(sp) ## Start Qt event loop unless running in interactive mode. if sys.flags.interactive != 1: diff --git a/opengl/GLGraphicsItem.py b/opengl/GLGraphicsItem.py index 96cc6763..7d1cf70b 100644 --- a/opengl/GLGraphicsItem.py +++ b/opengl/GLGraphicsItem.py @@ -1,4 +1,5 @@ from pyqtgraph.Qt import QtGui, QtCore +from pyqtgraph import Transform3D class GLGraphicsItem(QtCore.QObject): def __init__(self, parentItem=None): @@ -6,7 +7,7 @@ class GLGraphicsItem(QtCore.QObject): self.__parent = None self.__view = None self.__children = set() - self.__transform = QtGui.QMatrix4x4() + self.__transform = Transform3D() self.__visible = True self.setParentItem(parentItem) self.setDepthValue(0) @@ -50,7 +51,7 @@ class GLGraphicsItem(QtCore.QObject): return self.__depthValue def setTransform(self, tr): - self.__transform = tr + self.__transform = Transform3D(tr) self.update() def resetTransform(self): @@ -73,12 +74,22 @@ class GLGraphicsItem(QtCore.QObject): def transform(self): return self.__transform + def viewTransform(self): + tr = self.__transform + p = self + while True: + p = p.parentItem() + if p is None: + break + tr = p.transform() * tr + return Transform3D(tr) + def translate(self, dx, dy, dz, local=False): """ Translate the object by (*dx*, *dy*, *dz*) in its parent's coordinate system. If *local* is True, then translation takes place in local coordinates. """ - tr = QtGui.QMatrix4x4() + tr = Transform3D() tr.translate(dx, dy, dz) self.applyTransform(tr, local=local) @@ -88,7 +99,7 @@ class GLGraphicsItem(QtCore.QObject): *angle* is in degrees. """ - tr = QtGui.QMatrix4x4() + tr = Transform3D() tr.rotate(angle, x, y, z) self.applyTransform(tr, local=local) @@ -97,7 +108,7 @@ class GLGraphicsItem(QtCore.QObject): Scale the object by (*dx*, *dy*, *dz*) in its local coordinate system. If *local* is False, then scale takes place in the parent's coordinates. """ - tr = QtGui.QMatrix4x4() + tr = Transform3D() tr.scale(x, y, z) self.applyTransform(tr, local=local) @@ -138,8 +149,29 @@ class GLGraphicsItem(QtCore.QObject): return v.updateGL() + def mapToParent(self, point): + tr = self.transform() + if tr is None: + return point + return tr.map(point) + def mapFromParent(self, point): tr = self.transform() if tr is None: return point - return tr.inverted()[0].map(point) \ No newline at end of file + return tr.inverted()[0].map(point) + + def mapToView(self, point): + tr = self.viewTransform() + if tr is None: + return point + return tr.map(point) + + def mapFromView(self, point): + tr = self.viewTransform() + if tr is None: + return point + return tr.inverted()[0].map(point) + + + \ No newline at end of file diff --git a/opengl/GLViewWidget.py b/opengl/GLViewWidget.py index 3e105491..6911d849 100644 --- a/opengl/GLViewWidget.py +++ b/opengl/GLViewWidget.py @@ -1,8 +1,8 @@ from pyqtgraph.Qt import QtCore, QtGui, QtOpenGL from OpenGL.GL import * import numpy as np - -Vector = QtGui.QVector3D +from pyqtgraph import Vector +##Vector = QtGui.QVector3D class GLViewWidget(QtOpenGL.QGLWidget): """ @@ -181,10 +181,14 @@ class GLViewWidget(QtOpenGL.QGLWidget): def pixelSize(self, pos): """ Return the approximate size of a screen pixel at the location pos - + Pos may be a Vector or an (N,3) array of locations """ cam = self.cameraPosition() - dist = (pos-cam).length() + if isinstance(pos, np.ndarray) and pos.ndim == 2: + cam = np.array(cam).reshape(1,3) + dist = ((pos-cam)**2).sum(axis=1)**0.5 + else: + dist = (pos-cam).length() xDist = dist * 2. * np.tan(0.5 * self.opts['fov'] * np.pi / 180.) return xDist / self.width() diff --git a/opengl/items/GLMeshItem.py b/opengl/items/GLMeshItem.py index 790c6760..266b84c0 100644 --- a/opengl/items/GLMeshItem.py +++ b/opengl/items/GLMeshItem.py @@ -28,7 +28,7 @@ class GLMeshItem(GLGraphicsItem): GLGraphicsItem.__init__(self) def initializeGL(self): - self.shader = shaders.getShader('balloon') + self.shader = shaders.getShaderProgram('balloon') l = glGenLists(1) self.triList = l @@ -72,7 +72,9 @@ class GLMeshItem(GLGraphicsItem): def paint(self): - shaders.glUseProgram(self.shader) - glCallList(self.triList) - shaders.glUseProgram(0) + with self.shader: + glCallList(self.triList) + #shaders.glUseProgram(self.shader) + #glCallList(self.triList) + #shaders.glUseProgram(0) #glCallList(self.meshList) diff --git a/opengl/items/GLScatterPlotItem.py b/opengl/items/GLScatterPlotItem.py index 3ef3f11b..1134ce51 100644 --- a/opengl/items/GLScatterPlotItem.py +++ b/opengl/items/GLScatterPlotItem.py @@ -1,5 +1,7 @@ from OpenGL.GL import * +from OpenGL.arrays import vbo from .. GLGraphicsItem import GLGraphicsItem +from .. import shaders from pyqtgraph import QtGui import numpy as np @@ -14,6 +16,7 @@ class GLScatterPlotItem(GLGraphicsItem): self.size = 10 self.color = [1.0,1.0,1.0,0.5] self.pxMode = True + #self.vbo = {} ## VBO does not appear to improve performance very much. self.setData(**kwds) def setData(self, **kwds): @@ -39,13 +42,16 @@ class GLScatterPlotItem(GLGraphicsItem): for k in kwds.keys(): if k not in args: raise Exception('Invalid keyword argument: %s (allowed arguments are %s)' % (k, str(args))) - self.pos = kwds.get('pos', self.pos) - self.color = kwds.get('color', self.color) - self.size = kwds.get('size', self.size) + + args.remove('pxMode') + for arg in args: + if arg in kwds: + setattr(self, arg, kwds[arg]) + #self.vbo.pop(arg, None) + self.pxMode = kwds.get('pxMode', self.pxMode) self.update() - def initializeGL(self): ## Generate texture for rendering points @@ -65,73 +71,105 @@ class GLScatterPlotItem(GLGraphicsItem): glBindTexture(GL_TEXTURE_2D, self.pointTexture) glTexImage2D(GL_TEXTURE_2D, 0, GL_RGBA, pData.shape[0], pData.shape[1], 0, GL_RGBA, GL_UNSIGNED_BYTE, pData) - def paint(self): - glBlendFunc(GL_SRC_ALPHA, GL_ONE_MINUS_SRC_ALPHA) + self.shader = shaders.getShaderProgram('point_sprite') + + #def getVBO(self, name): + #if name not in self.vbo: + #self.vbo[name] = vbo.VBO(getattr(self, name).astype('f')) + #return self.vbo[name] + + def setupGLState(self): + """Prepare OpenGL state for drawing. This function is called immediately before painting.""" + #glBlendFunc(GL_SRC_ALPHA, GL_ONE_MINUS_SRC_ALPHA) ## requires z-sorting to render properly. + glBlendFunc(GL_SRC_ALPHA, GL_ONE) glEnable( GL_BLEND ) glEnable( GL_ALPHA_TEST ) - glEnable( GL_POINT_SMOOTH ) + glDisable( GL_DEPTH_TEST ) + + #glEnable( GL_POINT_SMOOTH ) - glHint(GL_POINT_SMOOTH_HINT, GL_NICEST) + #glHint(GL_POINT_SMOOTH_HINT, GL_NICEST) #glPointParameterfv(GL_POINT_DISTANCE_ATTENUATION, (0, 0, -1e-3)) #glPointParameterfv(GL_POINT_SIZE_MAX, (65500,)) #glPointParameterfv(GL_POINT_SIZE_MIN, (0,)) + def paint(self): + self.setupGLState() + glEnable(GL_POINT_SPRITE) + glActiveTexture(GL_TEXTURE0) glEnable( GL_TEXTURE_2D ) glBindTexture(GL_TEXTURE_2D, self.pointTexture) glTexEnvi(GL_POINT_SPRITE, GL_COORD_REPLACE, GL_TRUE) #glTexEnvi(GL_TEXTURE_ENV, GL_TEXTURE_ENV_MODE, GL_REPLACE) ## use texture color exactly - glTexEnvf( GL_TEXTURE_ENV, GL_TEXTURE_ENV_MODE, GL_MODULATE ) ## texture modulates current color + #glTexEnvf( GL_TEXTURE_ENV, GL_TEXTURE_ENV_MODE, GL_MODULATE ) ## texture modulates current color glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_LINEAR) glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_LINEAR) glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_CLAMP_TO_EDGE) glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE) + glEnable(GL_PROGRAM_POINT_SIZE) - if self.pxMode: - glVertexPointerf(self.pos) - if isinstance(self.color, np.ndarray): - glColorPointerf(self.color) - else: - if isinstance(self.color, QtGui.QColor): - glColor4f(*fn.glColor(self.color)) - else: - glColor4f(*self.color) - if isinstance(self.size, np.ndarray): - raise Exception('Array size not yet supported in pxMode (hopefully soon)') - - glPointSize(self.size) + with self.shader: + #glUniform1i(self.shader.uniform('texture'), 0) ## inform the shader which texture to use glEnableClientState(GL_VERTEX_ARRAY) - glEnableClientState(GL_COLOR_ARRAY) - glDrawArrays(GL_POINTS, 0, len(self.pos)) - else: + try: + glVertexPointerf(self.pos) - - for i in range(len(self.pos)): - pos = self.pos[i] - if isinstance(self.color, np.ndarray): - color = self.color[i] + glEnableClientState(GL_COLOR_ARRAY) + glColorPointerf(self.color) else: - color = self.color - if isinstance(self.color, QtGui.QColor): - color = fn.glColor(self.color) - - if isinstance(self.size, np.ndarray): - size = self.size[i] - else: - size = self.size - - pxSize = self.view().pixelSize(QtGui.QVector3D(*pos)) + if isinstance(self.color, QtGui.QColor): + glColor4f(*fn.glColor(self.color)) + else: + glColor4f(*self.color) - glPointSize(size / pxSize) - glBegin( GL_POINTS ) - glColor4f(*color) # x is blue - #glNormal3f(size, 0, 0) - glVertex3f(*pos) - glEnd() + if not self.pxMode or isinstance(self.size, np.ndarray): + glEnableClientState(GL_NORMAL_ARRAY) + norm = np.empty(self.pos.shape) + if self.pxMode: + norm[:,0] = self.size + else: + gpos = self.mapToView(self.pos.transpose()).transpose() + pxSize = self.view().pixelSize(gpos) + norm[:,0] = self.size / pxSize + + glNormalPointerf(norm) + else: + glPointSize(self.size) + glDrawArrays(GL_POINTS, 0, len(self.pos)) + finally: + glDisableClientState(GL_NORMAL_ARRAY) + glDisableClientState(GL_VERTEX_ARRAY) + glDisableClientState(GL_COLOR_ARRAY) + #posVBO.unbind() + + #for i in range(len(self.pos)): + #pos = self.pos[i] + + #if isinstance(self.color, np.ndarray): + #color = self.color[i] + #else: + #color = self.color + #if isinstance(self.color, QtGui.QColor): + #color = fn.glColor(self.color) + + #if isinstance(self.size, np.ndarray): + #size = self.size[i] + #else: + #size = self.size + + #pxSize = self.view().pixelSize(QtGui.QVector3D(*pos)) + + #glPointSize(size / pxSize) + #glBegin( GL_POINTS ) + #glColor4f(*color) # x is blue + ##glNormal3f(size, 0, 0) + #glVertex3f(*pos) + #glEnd() diff --git a/opengl/shaders.py b/opengl/shaders.py index b1216e35..7f4fa665 100644 --- a/opengl/shaders.py +++ b/opengl/shaders.py @@ -3,39 +3,107 @@ from OpenGL.GL import shaders ## For centralizing and managing vertex/fragment shader programs. +def initShaders(): + global Shaders + Shaders = [ + ShaderProgram('balloon', [ ## increases fragment alpha as the normal turns orthogonal to the view + VertexShader(""" + varying vec3 normal; + void main() { + normal = normalize(gl_NormalMatrix * gl_Normal); + //vec4 color = normal; + //normal.w = min(color.w + 2.0 * color.w * pow(normal.x*normal.x + normal.y*normal.y, 2.0), 1.0); + gl_FrontColor = gl_Color; + gl_BackColor = gl_Color; + gl_Position = ftransform(); + } + """), + FragmentShader(""" + varying vec3 normal; + void main() { + vec4 color = gl_Color; + color.w = min(color.w + 2.0 * color.w * pow(normal.x*normal.x + normal.y*normal.y, 5.0), 1.0); + gl_FragColor = color; + } + """) + ]), + ShaderProgram('point_sprite', [ ## allows specifying point size using normal.x + ## See: + ## + ## http://stackoverflow.com/questions/9609423/applying-part-of-a-texture-sprite-sheet-texture-map-to-a-point-sprite-in-ios + ## http://stackoverflow.com/questions/3497068/textured-points-in-opengl-es-2-0 + ## + ## + VertexShader(""" + void main() { + gl_FrontColor=gl_Color; + gl_PointSize = gl_Normal.x; + gl_Position = ftransform(); + } + """), + #FragmentShader(""" + ##version 120 + #uniform sampler2D texture; + #void main ( ) + #{ + #gl_FragColor = texture2D(texture, gl_PointCoord) * gl_Color; + #} + #""") + ]), + ] -Shaders = { - 'balloon': ( ## increases fragment alpha as the normal turns orthogonal to the view - """ - varying vec3 normal; - void main() { - normal = normalize(gl_NormalMatrix * gl_Normal); - //vec4 color = normal; - //normal.w = min(color.w + 2.0 * color.w * pow(normal.x*normal.x + normal.y*normal.y, 2.0), 1.0); - gl_FrontColor = gl_Color; - gl_BackColor = gl_Color; - gl_Position = ftransform(); - } - """, - """ - varying vec3 normal; - void main() { - vec4 color = gl_Color; - color.w = min(color.w + 2.0 * color.w * pow(normal.x*normal.x + normal.y*normal.y, 5.0), 1.0); - gl_FragColor = color; - } - """ - ), -} -CompiledShaders = {} + +CompiledShaderPrograms = {} -def getShader(name): - global Shaders, CompiledShaders +def getShaderProgram(name): + return ShaderProgram.names[name] + +class VertexShader: + def __init__(self, code): + self.code = code + self.compiled = None + + def shader(self): + if self.compiled is None: + self.compiled = shaders.compileShader(self.code, GL_VERTEX_SHADER) + return self.compiled + +class FragmentShader: + def __init__(self, code): + self.code = code + self.compiled = None + + def shader(self): + if self.compiled is None: + self.compiled = shaders.compileShader(self.code, GL_FRAGMENT_SHADER) + return self.compiled + + + +class ShaderProgram: + names = {} - if name not in CompiledShaders: - vshader, fshader = Shaders[name] - vcomp = shaders.compileShader(vshader, GL_VERTEX_SHADER) - fcomp = shaders.compileShader(fshader, GL_FRAGMENT_SHADER) - prog = shaders.compileProgram(vcomp, fcomp) - CompiledShaders[name] = prog, vcomp, fcomp - return CompiledShaders[name][0] + def __init__(self, name, shaders): + self.name = name + ShaderProgram.names[name] = self + self.shaders = shaders + self.prog = None + + def program(self): + if self.prog is None: + compiled = [s.shader() for s in self.shaders] ## compile all shaders + self.prog = shaders.compileProgram(*compiled) ## compile program + return self.prog + + def __enter__(self): + glUseProgram(self.program()) + + def __exit__(self, *args): + glUseProgram(0) + + def uniform(self, name): + """Return the location integer for a uniform variable in this program""" + return glGetUniformLocation(self.program(), name) + + +initShaders() \ No newline at end of file