schema: add paramref
[free-sw/xcb/proto] / xcbgen / expr.py
index 51e738f..e6895ff 100644 (file)
@@ -53,6 +53,8 @@ class Expression(object):
         self.lhs = None
         self.rhs = None
 
+        self.contains_listelement_ref = False
+
         if elt.tag == 'list':
             # List going into a request, which has no length field (inferred by server)
             self.lenfield_name = elt.get('name') + '_len'
@@ -62,6 +64,10 @@ class Expression(object):
             # Standard list with a fieldref
             self.lenfield_name = elt.text
 
+        elif elt.tag == 'paramref':
+            self.lenfield_name = elt.text
+            self.lenfield_type = elt.get('type')
+
         elif elt.tag == 'valueparam':
             # Value-mask.  The length bitmask is described by attributes.
             self.lenfield_name = elt.get('value-mask-name')
@@ -104,6 +110,17 @@ class Expression(object):
         elif elt.tag == 'sumof':
             self.op = 'sumof'
             self.lenfield_name = elt.get('ref')
+            subexpressions = list(elt)
+            if len(subexpressions) > 0:
+                # sumof with a nested expression which is to be evaluated
+                # for each list-element in the context of that list-element.
+                # sumof then returns the sum of the results of these evaluations
+                self.rhs = Expression(subexpressions[0], parent)
+
+        elif elt.tag == 'listelement-ref':
+            # current list element inside iterating expressions such as sumof
+            self.op = 'listelement-ref'
+            self.contains_listelement_ref = True
 
         else:
             # Notreached
@@ -112,6 +129,12 @@ class Expression(object):
     def fixed_size(self):
         return self.nmemb != None
 
+    def recursive_resolve_tasks(self, module, parents):
+        for subexpr in (self.lhs, self.rhs):
+            if subexpr != None:
+                subexpr.recursive_resolve_tasks(module, parents)
+                self.contains_listelement_ref |= subexpr.contains_listelement_ref
+
     def resolve(self, module, parents):
         if self.op == 'enumref':
             self.lenfield_type = module.get_type(self.lenfield_name[0])
@@ -128,4 +151,6 @@ class Expression(object):
                         self.lenfield_parent = p
                     self.lenfield_type = fields[self.lenfield_name].field_type
                     break
+
+        self.recursive_resolve_tasks(module, parents)