GLScatterPlotItem: use shader programs to allow specifying spot size by array

Reorganized shader programs
Infrastructure updates for OpenGL system
This commit is contained in:
Luke Campagnola 2012-10-26 21:47:45 -04:00
parent 450626a3bb
commit b09182d19a
10 changed files with 411 additions and 183 deletions

View File

@ -6,12 +6,12 @@ import pyqtgraph as pg
import numpy as np
import scipy.linalg
class SRTTransform3D(QtGui.QMatrix4x4):
class SRTTransform3D(pg.Transform3D):
"""4x4 Transform matrix that can always be represented as a combination of 3 matrices: scale * rotate * translate
This transform has no shear; angles are always preserved.
"""
def __init__(self, init=None):
QtGui.QMatrix4x4.__init__(self)
pg.Transform3D.__init__(self)
self.reset()
if init is None:
return
@ -190,11 +190,11 @@ class SRTTransform3D(QtGui.QMatrix4x4):
self.update()
def update(self):
QtGui.QMatrix4x4.setToIdentity(self)
pg.Transform3D.setToIdentity(self)
## modifications to the transform are multiplied on the right, so we need to reverse order here.
QtGui.QMatrix4x4.translate(self, *self._state['pos'])
QtGui.QMatrix4x4.rotate(self, self._state['angle'], *self._state['axis'])
QtGui.QMatrix4x4.scale(self, *self._state['scale'])
pg.Transform3D.translate(self, *self._state['pos'])
pg.Transform3D.rotate(self, self._state['angle'], *self._state['axis'])
pg.Transform3D.scale(self, *self._state['scale'])
def __repr__(self):
return str(self.saveState())

35
Transform3D.py Normal file
View File

@ -0,0 +1,35 @@
# -*- coding: utf-8 -*-
from .Qt import QtCore, QtGui
import pyqtgraph as pg
import numpy as np
class Transform3D(QtGui.QMatrix4x4):
"""
Extension of QMatrix4x4 with some helpful methods added.
"""
def __init__(self, *args):
QtGui.QMatrix4x4.__init__(self, *args)
def matrix(self, nd=3):
if nd == 3:
return np.array(self.copyDataTo()).reshape(4,4)
elif nd == 2:
m = np.array(self.copyDataTo()).reshape(4,4)
m[2] = m[3]
m[:,2] = m[:,3]
return m[:3,:3]
else:
raise Exception("Argument 'nd' must be 2 or 3")
def map(self, obj):
"""
Extends QMatrix4x4.map() to allow mapping (3, ...) arrays of coordinates
"""
if isinstance(obj, np.ndarray) and obj.ndim >= 2 and obj.shape[0] in (2,3):
return pg.transformCoordinates(self, obj)
else:
return QtGui.QMatrix4x4.map(self, obj)
def inverted(self):
inv, b = QtGui.QMatrix4x4.inverted(self)
return Transform3D(inv), b

123
Vector.py
View File

@ -1,59 +1,64 @@
# -*- coding: utf-8 -*-
"""
Vector.py - Extension of QVector3D which adds a few missing methods.
Copyright 2010 Luke Campagnola
Distributed under MIT/X11 license. See license.txt for more infomation.
"""
from .Qt import QtGui, QtCore
import numpy as np
class Vector(QtGui.QVector3D):
"""Extension of QVector3D which adds a few helpful methods."""
def __init__(self, *args):
if len(args) == 1:
if isinstance(args[0], QtCore.QSizeF):
QtGui.QVector3D.__init__(self, float(args[0].width()), float(args[0].height()), 0)
return
elif isinstance(args[0], QtCore.QPoint) or isinstance(args[0], QtCore.QPointF):
QtGui.QVector3D.__init__(self, float(args[0].x()), float(args[0].y()), 0)
elif hasattr(args[0], '__getitem__'):
vals = list(args[0])
if len(vals) == 2:
vals.append(0)
if len(vals) != 3:
raise Exception('Cannot init Vector with sequence of length %d' % len(args[0]))
QtGui.QVector3D.__init__(self, *vals)
return
elif len(args) == 2:
QtGui.QVector3D.__init__(self, args[0], args[1], 0)
return
QtGui.QVector3D.__init__(self, *args)
def __len__(self):
return 3
#def __reduce__(self):
#return (Point, (self.x(), self.y()))
def __getitem__(self, i):
if i == 0:
return self.x()
elif i == 1:
return self.y()
elif i == 2:
return self.z()
else:
raise IndexError("Point has no index %s" % str(i))
def __setitem__(self, i, x):
if i == 0:
return self.setX(x)
elif i == 1:
return self.setY(x)
elif i == 2:
return self.setZ(x)
else:
raise IndexError("Point has no index %s" % str(i))
# -*- coding: utf-8 -*-
"""
Vector.py - Extension of QVector3D which adds a few missing methods.
Copyright 2010 Luke Campagnola
Distributed under MIT/X11 license. See license.txt for more infomation.
"""
from .Qt import QtGui, QtCore
import numpy as np
class Vector(QtGui.QVector3D):
"""Extension of QVector3D which adds a few helpful methods."""
def __init__(self, *args):
if len(args) == 1:
if isinstance(args[0], QtCore.QSizeF):
QtGui.QVector3D.__init__(self, float(args[0].width()), float(args[0].height()), 0)
return
elif isinstance(args[0], QtCore.QPoint) or isinstance(args[0], QtCore.QPointF):
QtGui.QVector3D.__init__(self, float(args[0].x()), float(args[0].y()), 0)
elif hasattr(args[0], '__getitem__'):
vals = list(args[0])
if len(vals) == 2:
vals.append(0)
if len(vals) != 3:
raise Exception('Cannot init Vector with sequence of length %d' % len(args[0]))
QtGui.QVector3D.__init__(self, *vals)
return
elif len(args) == 2:
QtGui.QVector3D.__init__(self, args[0], args[1], 0)
return
QtGui.QVector3D.__init__(self, *args)
def __len__(self):
return 3
#def __reduce__(self):
#return (Point, (self.x(), self.y()))
def __getitem__(self, i):
if i == 0:
return self.x()
elif i == 1:
return self.y()
elif i == 2:
return self.z()
else:
raise IndexError("Point has no index %s" % str(i))
def __setitem__(self, i, x):
if i == 0:
return self.setX(x)
elif i == 1:
return self.setY(x)
elif i == 2:
return self.setZ(x)
else:
raise IndexError("Point has no index %s" % str(i))
def __iter__(self):
yield(self.x())
yield(self.y())
yield(self.z())

View File

@ -165,6 +165,7 @@ from .WidgetGroup import *
from .Point import Point
from .Vector import Vector
from .SRTTransform import SRTTransform
from .Transform3D import Transform3D
from .SRTTransform3D import SRTTransform3D
from .functions import *
from .graphicsWindows import *

View File

@ -15,48 +15,91 @@ w.show()
g = gl.GLGridItem()
w.addItem(g)
#pos = np.empty((53, 3))
#size = np.empty((53))
#color = np.empty((53, 4))
#pos[0] = (1,0,0); size[0] = 0.5; color[0] = (1.0, 0.0, 0.0, 0.5)
#pos[1] = (0,1,0); size[1] = 0.2; color[1] = (0.0, 0.0, 1.0, 0.5)
#pos[2] = (0,0,1); size[2] = 2./3.; color[2] = (0.0, 1.0, 0.0, 0.5)
#z = 0.5
#d = 6.0
#for i in range(3,53):
#pos[i] = (0,0,z)
#size[i] = 2./d
#color[i] = (0.0, 1.0, 0.0, 0.5)
#z *= 0.5
#d *= 2.0
##
## First example is a set of points with pxMode=False
## These demonstrate the ability to have points with real size down to a very small scale
##
pos = np.empty((53, 3))
size = np.empty((53))
color = np.empty((53, 4))
pos[0] = (1,0,0); size[0] = 0.5; color[0] = (1.0, 0.0, 0.0, 0.5)
pos[1] = (0,1,0); size[1] = 0.2; color[1] = (0.0, 0.0, 1.0, 0.5)
pos[2] = (0,0,1); size[2] = 2./3.; color[2] = (0.0, 1.0, 0.0, 0.5)
z = 0.5
d = 6.0
for i in range(3,53):
pos[i] = (0,0,z)
size[i] = 2./d
color[i] = (0.0, 1.0, 0.0, 0.5)
z *= 0.5
d *= 2.0
#sp = gl.GLScatterPlotItem(pos=pos, sizes=size, colors=color, pxMode=False)
sp1 = gl.GLScatterPlotItem(pos=pos, size=size, color=color, pxMode=False)
sp1.translate(5,5,0)
w.addItem(sp1)
pos = (np.random.random(size=(100000,3)) * 10) - 5
##
## Second example shows a volume of points with rapidly updating color
## and pxMode=True
##
pos = np.random.random(size=(100000,3))
pos *= [10,-10,10]
pos[0] = (0,0,0)
color = np.ones((pos.shape[0], 4))
d = (pos**2).sum(axis=1)**0.5
color[:,3] = np.clip(-np.cos(d*2) * 0.2, 0, 1)
sp = gl.GLScatterPlotItem(pos=pos, color=color, size=5)
d2 = (pos**2).sum(axis=1)**0.5
size = np.random.random(size=pos.shape[0])*10
sp2 = gl.GLScatterPlotItem(pos=pos, color=(1,1,1,1), size=size)
phase = 0.
w.addItem(sp2)
##
## Third example shows a grid of points with rapidly updating position
## and pxMode = False
##
pos3 = np.zeros((100,100,3))
pos3[:,:,:2] = np.mgrid[:100, :100].transpose(1,2,0) * [-0.1,0.1]
pos3 = pos3.reshape(10000,3)
d3 = (pos3**2).sum(axis=1)**0.5
sp3 = gl.GLScatterPlotItem(pos=pos3, color=(1,1,1,.3), size=0.1, pxMode=False)
w.addItem(sp3)
def update():
global phase, color, sp, d
s = -np.cos(d*2+phase)
color[:,3] = np.clip(s * 0.2, 0, 1)
## update volume colors
global phase, sp2, d2
s = -np.cos(d2*2+phase)
color = np.empty((len(d2),4), dtype=np.float32)
color[:,3] = np.clip(s * 0.1, 0, 1)
color[:,0] = np.clip(s * 3.0, 0, 1)
color[:,1] = np.clip(s * 1.0, 0, 1)
color[:,2] = np.clip(s ** 3, 0, 1)
sp.setData(color=color)
sp2.setData(color=color)
phase -= 0.1
## update surface positions and colors
global sp3, d3, pos3
z = -np.cos(d3*2+phase)
pos3[:,2] = z
color = np.empty((len(d3),4), dtype=np.float32)
color[:,3] = 0.3
color[:,0] = np.clip(z * 3.0, 0, 1)
color[:,1] = np.clip(z * 1.0, 0, 1)
color[:,2] = np.clip(z ** 3, 0, 1)
sp3.setData(pos=pos3, color=color)
t = QtCore.QTimer()
t.timeout.connect(update)
t.start(50)
w.addItem(sp)
## Start Qt event loop unless running in interactive mode.
if sys.flags.interactive != 1:

View File

@ -1,4 +1,5 @@
from pyqtgraph.Qt import QtGui, QtCore
from pyqtgraph import Transform3D
class GLGraphicsItem(QtCore.QObject):
def __init__(self, parentItem=None):
@ -6,7 +7,7 @@ class GLGraphicsItem(QtCore.QObject):
self.__parent = None
self.__view = None
self.__children = set()
self.__transform = QtGui.QMatrix4x4()
self.__transform = Transform3D()
self.__visible = True
self.setParentItem(parentItem)
self.setDepthValue(0)
@ -50,7 +51,7 @@ class GLGraphicsItem(QtCore.QObject):
return self.__depthValue
def setTransform(self, tr):
self.__transform = tr
self.__transform = Transform3D(tr)
self.update()
def resetTransform(self):
@ -73,12 +74,22 @@ class GLGraphicsItem(QtCore.QObject):
def transform(self):
return self.__transform
def viewTransform(self):
tr = self.__transform
p = self
while True:
p = p.parentItem()
if p is None:
break
tr = p.transform() * tr
return Transform3D(tr)
def translate(self, dx, dy, dz, local=False):
"""
Translate the object by (*dx*, *dy*, *dz*) in its parent's coordinate system.
If *local* is True, then translation takes place in local coordinates.
"""
tr = QtGui.QMatrix4x4()
tr = Transform3D()
tr.translate(dx, dy, dz)
self.applyTransform(tr, local=local)
@ -88,7 +99,7 @@ class GLGraphicsItem(QtCore.QObject):
*angle* is in degrees.
"""
tr = QtGui.QMatrix4x4()
tr = Transform3D()
tr.rotate(angle, x, y, z)
self.applyTransform(tr, local=local)
@ -97,7 +108,7 @@ class GLGraphicsItem(QtCore.QObject):
Scale the object by (*dx*, *dy*, *dz*) in its local coordinate system.
If *local* is False, then scale takes place in the parent's coordinates.
"""
tr = QtGui.QMatrix4x4()
tr = Transform3D()
tr.scale(x, y, z)
self.applyTransform(tr, local=local)
@ -138,8 +149,29 @@ class GLGraphicsItem(QtCore.QObject):
return
v.updateGL()
def mapToParent(self, point):
tr = self.transform()
if tr is None:
return point
return tr.map(point)
def mapFromParent(self, point):
tr = self.transform()
if tr is None:
return point
return tr.inverted()[0].map(point)
return tr.inverted()[0].map(point)
def mapToView(self, point):
tr = self.viewTransform()
if tr is None:
return point
return tr.map(point)
def mapFromView(self, point):
tr = self.viewTransform()
if tr is None:
return point
return tr.inverted()[0].map(point)

View File

@ -1,8 +1,8 @@
from pyqtgraph.Qt import QtCore, QtGui, QtOpenGL
from OpenGL.GL import *
import numpy as np
Vector = QtGui.QVector3D
from pyqtgraph import Vector
##Vector = QtGui.QVector3D
class GLViewWidget(QtOpenGL.QGLWidget):
"""
@ -181,10 +181,14 @@ class GLViewWidget(QtOpenGL.QGLWidget):
def pixelSize(self, pos):
"""
Return the approximate size of a screen pixel at the location pos
Pos may be a Vector or an (N,3) array of locations
"""
cam = self.cameraPosition()
dist = (pos-cam).length()
if isinstance(pos, np.ndarray) and pos.ndim == 2:
cam = np.array(cam).reshape(1,3)
dist = ((pos-cam)**2).sum(axis=1)**0.5
else:
dist = (pos-cam).length()
xDist = dist * 2. * np.tan(0.5 * self.opts['fov'] * np.pi / 180.)
return xDist / self.width()

View File

@ -28,7 +28,7 @@ class GLMeshItem(GLGraphicsItem):
GLGraphicsItem.__init__(self)
def initializeGL(self):
self.shader = shaders.getShader('balloon')
self.shader = shaders.getShaderProgram('balloon')
l = glGenLists(1)
self.triList = l
@ -72,7 +72,9 @@ class GLMeshItem(GLGraphicsItem):
def paint(self):
shaders.glUseProgram(self.shader)
glCallList(self.triList)
shaders.glUseProgram(0)
with self.shader:
glCallList(self.triList)
#shaders.glUseProgram(self.shader)
#glCallList(self.triList)
#shaders.glUseProgram(0)
#glCallList(self.meshList)

View File

@ -1,5 +1,7 @@
from OpenGL.GL import *
from OpenGL.arrays import vbo
from .. GLGraphicsItem import GLGraphicsItem
from .. import shaders
from pyqtgraph import QtGui
import numpy as np
@ -14,6 +16,7 @@ class GLScatterPlotItem(GLGraphicsItem):
self.size = 10
self.color = [1.0,1.0,1.0,0.5]
self.pxMode = True
#self.vbo = {} ## VBO does not appear to improve performance very much.
self.setData(**kwds)
def setData(self, **kwds):
@ -39,13 +42,16 @@ class GLScatterPlotItem(GLGraphicsItem):
for k in kwds.keys():
if k not in args:
raise Exception('Invalid keyword argument: %s (allowed arguments are %s)' % (k, str(args)))
self.pos = kwds.get('pos', self.pos)
self.color = kwds.get('color', self.color)
self.size = kwds.get('size', self.size)
args.remove('pxMode')
for arg in args:
if arg in kwds:
setattr(self, arg, kwds[arg])
#self.vbo.pop(arg, None)
self.pxMode = kwds.get('pxMode', self.pxMode)
self.update()
def initializeGL(self):
## Generate texture for rendering points
@ -65,73 +71,105 @@ class GLScatterPlotItem(GLGraphicsItem):
glBindTexture(GL_TEXTURE_2D, self.pointTexture)
glTexImage2D(GL_TEXTURE_2D, 0, GL_RGBA, pData.shape[0], pData.shape[1], 0, GL_RGBA, GL_UNSIGNED_BYTE, pData)
def paint(self):
glBlendFunc(GL_SRC_ALPHA, GL_ONE_MINUS_SRC_ALPHA)
self.shader = shaders.getShaderProgram('point_sprite')
#def getVBO(self, name):
#if name not in self.vbo:
#self.vbo[name] = vbo.VBO(getattr(self, name).astype('f'))
#return self.vbo[name]
def setupGLState(self):
"""Prepare OpenGL state for drawing. This function is called immediately before painting."""
#glBlendFunc(GL_SRC_ALPHA, GL_ONE_MINUS_SRC_ALPHA) ## requires z-sorting to render properly.
glBlendFunc(GL_SRC_ALPHA, GL_ONE)
glEnable( GL_BLEND )
glEnable( GL_ALPHA_TEST )
glEnable( GL_POINT_SMOOTH )
glDisable( GL_DEPTH_TEST )
#glEnable( GL_POINT_SMOOTH )
glHint(GL_POINT_SMOOTH_HINT, GL_NICEST)
#glHint(GL_POINT_SMOOTH_HINT, GL_NICEST)
#glPointParameterfv(GL_POINT_DISTANCE_ATTENUATION, (0, 0, -1e-3))
#glPointParameterfv(GL_POINT_SIZE_MAX, (65500,))
#glPointParameterfv(GL_POINT_SIZE_MIN, (0,))
def paint(self):
self.setupGLState()
glEnable(GL_POINT_SPRITE)
glActiveTexture(GL_TEXTURE0)
glEnable( GL_TEXTURE_2D )
glBindTexture(GL_TEXTURE_2D, self.pointTexture)
glTexEnvi(GL_POINT_SPRITE, GL_COORD_REPLACE, GL_TRUE)
#glTexEnvi(GL_TEXTURE_ENV, GL_TEXTURE_ENV_MODE, GL_REPLACE) ## use texture color exactly
glTexEnvf( GL_TEXTURE_ENV, GL_TEXTURE_ENV_MODE, GL_MODULATE ) ## texture modulates current color
#glTexEnvf( GL_TEXTURE_ENV, GL_TEXTURE_ENV_MODE, GL_MODULATE ) ## texture modulates current color
glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_LINEAR)
glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_LINEAR)
glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_CLAMP_TO_EDGE)
glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE)
glEnable(GL_PROGRAM_POINT_SIZE)
if self.pxMode:
glVertexPointerf(self.pos)
if isinstance(self.color, np.ndarray):
glColorPointerf(self.color)
else:
if isinstance(self.color, QtGui.QColor):
glColor4f(*fn.glColor(self.color))
else:
glColor4f(*self.color)
if isinstance(self.size, np.ndarray):
raise Exception('Array size not yet supported in pxMode (hopefully soon)')
glPointSize(self.size)
with self.shader:
#glUniform1i(self.shader.uniform('texture'), 0) ## inform the shader which texture to use
glEnableClientState(GL_VERTEX_ARRAY)
glEnableClientState(GL_COLOR_ARRAY)
glDrawArrays(GL_POINTS, 0, len(self.pos))
else:
try:
glVertexPointerf(self.pos)
for i in range(len(self.pos)):
pos = self.pos[i]
if isinstance(self.color, np.ndarray):
color = self.color[i]
glEnableClientState(GL_COLOR_ARRAY)
glColorPointerf(self.color)
else:
color = self.color
if isinstance(self.color, QtGui.QColor):
color = fn.glColor(self.color)
if isinstance(self.size, np.ndarray):
size = self.size[i]
else:
size = self.size
pxSize = self.view().pixelSize(QtGui.QVector3D(*pos))
if isinstance(self.color, QtGui.QColor):
glColor4f(*fn.glColor(self.color))
else:
glColor4f(*self.color)
glPointSize(size / pxSize)
glBegin( GL_POINTS )
glColor4f(*color) # x is blue
#glNormal3f(size, 0, 0)
glVertex3f(*pos)
glEnd()
if not self.pxMode or isinstance(self.size, np.ndarray):
glEnableClientState(GL_NORMAL_ARRAY)
norm = np.empty(self.pos.shape)
if self.pxMode:
norm[:,0] = self.size
else:
gpos = self.mapToView(self.pos.transpose()).transpose()
pxSize = self.view().pixelSize(gpos)
norm[:,0] = self.size / pxSize
glNormalPointerf(norm)
else:
glPointSize(self.size)
glDrawArrays(GL_POINTS, 0, len(self.pos))
finally:
glDisableClientState(GL_NORMAL_ARRAY)
glDisableClientState(GL_VERTEX_ARRAY)
glDisableClientState(GL_COLOR_ARRAY)
#posVBO.unbind()
#for i in range(len(self.pos)):
#pos = self.pos[i]
#if isinstance(self.color, np.ndarray):
#color = self.color[i]
#else:
#color = self.color
#if isinstance(self.color, QtGui.QColor):
#color = fn.glColor(self.color)
#if isinstance(self.size, np.ndarray):
#size = self.size[i]
#else:
#size = self.size
#pxSize = self.view().pixelSize(QtGui.QVector3D(*pos))
#glPointSize(size / pxSize)
#glBegin( GL_POINTS )
#glColor4f(*color) # x is blue
##glNormal3f(size, 0, 0)
#glVertex3f(*pos)
#glEnd()

View File

@ -3,39 +3,107 @@ from OpenGL.GL import shaders
## For centralizing and managing vertex/fragment shader programs.
def initShaders():
global Shaders
Shaders = [
ShaderProgram('balloon', [ ## increases fragment alpha as the normal turns orthogonal to the view
VertexShader("""
varying vec3 normal;
void main() {
normal = normalize(gl_NormalMatrix * gl_Normal);
//vec4 color = normal;
//normal.w = min(color.w + 2.0 * color.w * pow(normal.x*normal.x + normal.y*normal.y, 2.0), 1.0);
gl_FrontColor = gl_Color;
gl_BackColor = gl_Color;
gl_Position = ftransform();
}
"""),
FragmentShader("""
varying vec3 normal;
void main() {
vec4 color = gl_Color;
color.w = min(color.w + 2.0 * color.w * pow(normal.x*normal.x + normal.y*normal.y, 5.0), 1.0);
gl_FragColor = color;
}
""")
]),
ShaderProgram('point_sprite', [ ## allows specifying point size using normal.x
## See:
##
## http://stackoverflow.com/questions/9609423/applying-part-of-a-texture-sprite-sheet-texture-map-to-a-point-sprite-in-ios
## http://stackoverflow.com/questions/3497068/textured-points-in-opengl-es-2-0
##
##
VertexShader("""
void main() {
gl_FrontColor=gl_Color;
gl_PointSize = gl_Normal.x;
gl_Position = ftransform();
}
"""),
#FragmentShader("""
##version 120
#uniform sampler2D texture;
#void main ( )
#{
#gl_FragColor = texture2D(texture, gl_PointCoord) * gl_Color;
#}
#""")
]),
]
Shaders = {
'balloon': ( ## increases fragment alpha as the normal turns orthogonal to the view
"""
varying vec3 normal;
void main() {
normal = normalize(gl_NormalMatrix * gl_Normal);
//vec4 color = normal;
//normal.w = min(color.w + 2.0 * color.w * pow(normal.x*normal.x + normal.y*normal.y, 2.0), 1.0);
gl_FrontColor = gl_Color;
gl_BackColor = gl_Color;
gl_Position = ftransform();
}
""",
"""
varying vec3 normal;
void main() {
vec4 color = gl_Color;
color.w = min(color.w + 2.0 * color.w * pow(normal.x*normal.x + normal.y*normal.y, 5.0), 1.0);
gl_FragColor = color;
}
"""
),
}
CompiledShaders = {}
CompiledShaderPrograms = {}
def getShader(name):
global Shaders, CompiledShaders
def getShaderProgram(name):
return ShaderProgram.names[name]
class VertexShader:
def __init__(self, code):
self.code = code
self.compiled = None
def shader(self):
if self.compiled is None:
self.compiled = shaders.compileShader(self.code, GL_VERTEX_SHADER)
return self.compiled
class FragmentShader:
def __init__(self, code):
self.code = code
self.compiled = None
def shader(self):
if self.compiled is None:
self.compiled = shaders.compileShader(self.code, GL_FRAGMENT_SHADER)
return self.compiled
class ShaderProgram:
names = {}
if name not in CompiledShaders:
vshader, fshader = Shaders[name]
vcomp = shaders.compileShader(vshader, GL_VERTEX_SHADER)
fcomp = shaders.compileShader(fshader, GL_FRAGMENT_SHADER)
prog = shaders.compileProgram(vcomp, fcomp)
CompiledShaders[name] = prog, vcomp, fcomp
return CompiledShaders[name][0]
def __init__(self, name, shaders):
self.name = name
ShaderProgram.names[name] = self
self.shaders = shaders
self.prog = None
def program(self):
if self.prog is None:
compiled = [s.shader() for s in self.shaders] ## compile all shaders
self.prog = shaders.compileProgram(*compiled) ## compile program
return self.prog
def __enter__(self):
glUseProgram(self.program())
def __exit__(self, *args):
glUseProgram(0)
def uniform(self, name):
"""Return the location integer for a uniform variable in this program"""
return glGetUniformLocation(self.program(), name)
initShaders()