[gtk+/wip/otte/shader: 38/98] gskslexpression: Move division to the new binary vfuncs



commit 22ce96f0f9b8503cf78d0dfab5175a01bb61b283
Author: Benjamin Otte <otte redhat com>
Date:   Sun Oct 8 04:19:29 2017 +0200

    gskslexpression: Move division to the new binary vfuncs

 gsk/gskslbinary.c     |  378 +++++++++++++++++++++++++++++++
 gsk/gskslexpression.c |  602 ++++++-------------------------------------------
 2 files changed, 445 insertions(+), 535 deletions(-)
---
diff --git a/gsk/gskslbinary.c b/gsk/gskslbinary.c
index 57ee881..e7a2f71 100644
--- a/gsk/gskslbinary.c
+++ b/gsk/gskslbinary.c
@@ -413,6 +413,372 @@ static const GskSlBinary GSK_SL_BINARY_MULTIPLICATION = {
   gsk_sl_multiplication_write_spv
 };
 
+/* DIVISION */
+
+/* Shared type check for the component-wise arithmetic operators that use it
+ * (division, addition and subtraction in this file).
+ *
+ * Computes the result type of "ltype OP rtype":
+ *  - one operand's scalar type must convert to the other's; the target of that
+ *    conversion becomes the result's scalar type, and it must be numeric
+ *    (int, uint, float or double);
+ *  - matrix OP matrix needs matrix types where one converts to the other;
+ *  - vector OP vector needs vectors of equal length;
+ *  - a scalar may be combined with a matrix, a vector or another scalar;
+ *  - matrix OP vector and vector OP matrix are rejected.
+ *
+ * Returns: the result type, or %NULL after reporting a TYPE_MISMATCH error
+ *     through @preproc.
+ */
+static GskSlType *
+gsk_sl_arithmetic_check_type (GskSlPreprocessor *preproc,
+                              GskSlType         *ltype,
+                              GskSlType         *rtype)
+{
+  GskSlScalarType scalar;
+
+  if (gsk_sl_scalar_type_can_convert (gsk_sl_type_get_scalar_type (ltype),
+                                      gsk_sl_type_get_scalar_type (rtype)))
+    scalar = gsk_sl_type_get_scalar_type (ltype);
+  else if (gsk_sl_scalar_type_can_convert (gsk_sl_type_get_scalar_type (rtype),
+                                           gsk_sl_type_get_scalar_type (ltype)))
+    scalar = gsk_sl_type_get_scalar_type (rtype);
+  else
+    {
+      gsk_sl_preprocessor_error (preproc, TYPE_MISMATCH,
+                                 "Operand types %s and %s do not share compatible scalar types.",
+                                 gsk_sl_type_get_name (ltype), gsk_sl_type_get_name (rtype));
+      return NULL;
+    }
+
+  /* Arithmetic is only defined on numeric components. For division, letting
+   * bool (or void) through would hit g_assert_not_reached() in
+   * gsk_sl_division_write_spv() and call an unset (NULL) entry of div_funcs
+   * when constant folding. */
+  switch (scalar)
+    {
+    case GSK_SL_INT:
+    case GSK_SL_UINT:
+    case GSK_SL_FLOAT:
+    case GSK_SL_DOUBLE:
+      break;
+
+    case GSK_SL_VOID:
+    case GSK_SL_BOOL:
+    default:
+      gsk_sl_preprocessor_error (preproc, TYPE_MISMATCH,
+                                 "Cannot perform arithmetic on operand types %s and %s.",
+                                 gsk_sl_type_get_name (ltype), gsk_sl_type_get_name (rtype));
+      return NULL;
+    }
+
+  if (gsk_sl_type_is_matrix (ltype))
+    {
+      if (gsk_sl_type_is_matrix (rtype))
+        {
+          if (gsk_sl_type_can_convert (ltype, rtype))
+            {
+              return ltype;
+            }
+          else if (gsk_sl_type_can_convert (rtype, ltype))
+            {
+              return rtype;
+            }
+          else
+            {
+              gsk_sl_preprocessor_error (preproc, TYPE_MISMATCH,
+                                         "Matrix types %s and %s have different size.",
+                                         gsk_sl_type_get_name (ltype), gsk_sl_type_get_name (rtype));
+              return NULL;
+            }
+        }
+      else if (gsk_sl_type_is_vector (rtype))
+        {
+          gsk_sl_preprocessor_error (preproc, TYPE_MISMATCH,
+                                     "Cannot perform arithmetic between matrix and vector.");
+          return NULL;
+        }
+      else if (gsk_sl_type_is_scalar (rtype))
+        {
+          return gsk_sl_type_get_matrix (scalar,
+                                         gsk_sl_type_get_length (ltype),
+                                         gsk_sl_type_get_length (gsk_sl_type_get_index_type (ltype)));
+        }
+      else
+        {
+          gsk_sl_preprocessor_error (preproc, TYPE_MISMATCH,
+                                     "Right operand is incompatible type for arithmetic.");
+          return NULL;
+        }
+    }
+  else if (gsk_sl_type_is_vector (ltype))
+    {
+      if (gsk_sl_type_is_matrix (rtype))
+        {
+          gsk_sl_preprocessor_error (preproc, TYPE_MISMATCH,
+                                     "Cannot perform arithmetic between vector and matrix.");
+          return NULL;
+        }
+      else if (gsk_sl_type_is_vector (rtype))
+        {
+          if (gsk_sl_type_get_length (ltype) != gsk_sl_type_get_length (rtype))
+            {
+              gsk_sl_preprocessor_error (preproc, TYPE_MISMATCH,
+                                         "Vector operands %s and %s to arithmetic have different length.",
+                                         gsk_sl_type_get_name (ltype), gsk_sl_type_get_name (rtype));
+              return NULL;
+            }
+          return gsk_sl_type_get_vector (scalar, gsk_sl_type_get_length (ltype));
+        }
+      else if (gsk_sl_type_is_scalar (rtype))
+        {
+          return gsk_sl_type_get_vector (scalar,
+                                         gsk_sl_type_get_length (ltype));
+        }
+      else
+        {
+          gsk_sl_preprocessor_error (preproc, TYPE_MISMATCH,
+                                     "Right operand is incompatible type for arithmetic.");
+          return NULL;
+        }
+    }
+  else if (gsk_sl_type_is_scalar (ltype))
+    {
+      if (gsk_sl_type_is_matrix (rtype))
+        {
+          return gsk_sl_type_get_matrix (scalar,
+                                         gsk_sl_type_get_length (rtype),
+                                         gsk_sl_type_get_length (gsk_sl_type_get_index_type (rtype)));
+        }
+      else if (gsk_sl_type_is_vector (rtype))
+        {
+          return gsk_sl_type_get_vector (scalar,
+                                         gsk_sl_type_get_length (rtype));
+        }
+      else if (gsk_sl_type_is_scalar (rtype))
+        {
+          return gsk_sl_type_get_scalar (scalar);
+        }
+      else
+        {
+          gsk_sl_preprocessor_error (preproc, TYPE_MISMATCH,
+                                     "Right operand is incompatible type for arithmetic.");
+          return NULL;
+        }
+    }
+  else
+    {
+      gsk_sl_preprocessor_error (preproc, TYPE_MISMATCH,
+                                 "Left operand is incompatible type for arithmetic.");
+      return NULL;
+    }
+}
+
+/* Scalar kernels used to constant-fold division. In GSK_SL_BINARY_FUNC_SCALAR,
+ * x is read from the component being updated and written back, and y is the
+ * other operand (assumed to match the GSK_SL_OPERATION_FUNC_SCALAR macro that
+ * this patch removes from gskslexpression.c; confirm).
+ *
+ * The plain variants compute x / y (the composite is the dividend). The _inv
+ * variants compute y / x (the composite is the divisor, y the scalar dividend).
+ *
+ * Integer division by zero is undefined behaviour in C, so it folds to the
+ * type's maximum value. G_MININT32 / -1 overflows gint32 (also undefined
+ * behaviour), so it folds to G_MININT32, the two's-complement wrap-around
+ * result.
+ */
+GSK_SL_BINARY_FUNC_SCALAR(gsk_sl_expression_division_int, gint32, x = y == 0 ? G_MAXINT32 : ((x == G_MININT32 && y == -1) ? G_MININT32 : x / y);)
+GSK_SL_BINARY_FUNC_SCALAR(gsk_sl_expression_division_uint, guint32, x = y == 0 ? G_MAXUINT32 : x / y;)
+GSK_SL_BINARY_FUNC_SCALAR(gsk_sl_expression_division_float, float, x /= y;)
+GSK_SL_BINARY_FUNC_SCALAR(gsk_sl_expression_division_double, double, x /= y;)
+GSK_SL_BINARY_FUNC_SCALAR(gsk_sl_expression_division_int_inv, gint32, x = x == 0 ? G_MAXINT32 : ((y == G_MININT32 && x == -1) ? G_MININT32 : y / x);)
+GSK_SL_BINARY_FUNC_SCALAR(gsk_sl_expression_division_uint_inv, guint32, x = x == 0 ? G_MAXUINT32 : y / x;)
+GSK_SL_BINARY_FUNC_SCALAR(gsk_sl_expression_division_float_inv, float, x = y / x;)
+GSK_SL_BINARY_FUNC_SCALAR(gsk_sl_expression_division_double_inv, double, x = y / x;)
+/* Indexed by GskSlScalarType; the void and bool slots are left NULL. */
+static void (* div_funcs[]) (gpointer, gpointer) = {
+  [GSK_SL_INT] = gsk_sl_expression_division_int,
+  [GSK_SL_UINT] = gsk_sl_expression_division_uint,
+  [GSK_SL_FLOAT] = gsk_sl_expression_division_float,
+  [GSK_SL_DOUBLE] = gsk_sl_expression_division_double,
+};
+static void (* div_inv_funcs[]) (gpointer, gpointer) = {
+  [GSK_SL_INT] = gsk_sl_expression_division_int_inv,
+  [GSK_SL_UINT] = gsk_sl_expression_division_uint_inv,
+  [GSK_SL_FLOAT] = gsk_sl_expression_division_float_inv,
+  [GSK_SL_DOUBLE] = gsk_sl_expression_division_double_inv,
+};
+
+/* Constant-folds "lvalue / rvalue" into a value of @type.
+ *
+ * Takes ownership of both values. Each is converted to the scalar type of
+ * @type, the division is carried out in place in one of them, and the other
+ * is freed. A single-component operand is broadcast against every component
+ * of the other one. Otherwise the operands are divided component by component.
+ */
+static GskSlValue *
+gsk_sl_division_get_constant (GskSlType  *type,
+                              GskSlValue *lvalue,
+                              GskSlValue *rvalue)
+{
+  GskSlScalarType scalar = gsk_sl_type_get_scalar_type (type);
+  gsize n_left, n_right;
+  guchar *dividend, *divisor;
+  gsize size, i;
+
+  lvalue = gsk_sl_value_convert_components (lvalue, scalar);
+  rvalue = gsk_sl_value_convert_components (rvalue, scalar);
+  n_left = gsk_sl_type_get_n_components (gsk_sl_value_get_type (lvalue));
+  n_right = gsk_sl_type_get_n_components (gsk_sl_value_get_type (rvalue));
+
+  if (n_left == 1)
+    {
+      /* scalar dividend: rvalue is rewritten as lvalue / rvalue[i] */
+      gsk_sl_value_componentwise (rvalue, div_inv_funcs[scalar], gsk_sl_value_get_data (lvalue));
+      gsk_sl_value_free (lvalue);
+      return rvalue;
+    }
+
+  if (n_right == 1)
+    {
+      /* scalar divisor: lvalue is rewritten as lvalue[i] / rvalue */
+      gsk_sl_value_componentwise (lvalue, div_funcs[scalar], gsk_sl_value_get_data (rvalue));
+      gsk_sl_value_free (rvalue);
+      return lvalue;
+    }
+
+  /* Both sides are composites; the type check only lets equally sized ones
+   * through, so walk them in lockstep. */
+  size = gsk_sl_scalar_type_get_size (scalar);
+  dividend = gsk_sl_value_get_data (lvalue);
+  divisor = gsk_sl_value_get_data (rvalue);
+  for (i = 0; i < n_left; i++, dividend += size, divisor += size)
+    div_funcs[scalar] (dividend, divisor);
+
+  gsk_sl_value_free (rvalue);
+  return lvalue;
+}
+
+/* Emits SPIR-V for "left / right" and returns the id of the result.
+ *
+ * Both operands are first converted to the scalar type of @type. Then:
+ *  - matrix / matrix divides column by column;
+ *  - matrix / scalar multiplies the matrix by (1 / scalar);
+ *  - scalar / matrix replicates the scalar into a column vector and divides
+ *    that vector by each column, so every component is left / right[c][r];
+ *  - for scalars and vectors, a scalar operand is replicated to the vector
+ *    length and a single float, signed or unsigned division is emitted.
+ *
+ * Matrix / vector and vector / matrix never get here, because
+ * gsk_sl_arithmetic_check_type() rejects them.
+ */
+static guint32
+gsk_sl_division_write_spv (GskSpvWriter *writer,
+                           GskSlType    *type,
+                           GskSlType    *ltype,
+                           guint32       left_id,
+                           GskSlType    *rtype,
+                           guint32       right_id)
+{
+  if (gsk_sl_type_get_scalar_type (ltype) != gsk_sl_type_get_scalar_type (type))
+    {
+      GskSlType *new_type = gsk_sl_type_get_matching (ltype, gsk_sl_type_get_scalar_type (type));
+      left_id = gsk_spv_writer_convert (writer, left_id, ltype, new_type);
+      ltype = new_type;
+    }
+  if (gsk_sl_type_get_scalar_type (rtype) != gsk_sl_type_get_scalar_type (type))
+    {
+      GskSlType *new_type = gsk_sl_type_get_matching (rtype, gsk_sl_type_get_scalar_type (type));
+      right_id = gsk_spv_writer_convert (writer, right_id, rtype, new_type);
+      rtype = new_type;
+    }
+
+  if (gsk_sl_type_is_matrix (ltype))
+    {
+      if (gsk_sl_type_is_matrix (rtype))
+        {
+          GskSlType *col_type = gsk_sl_type_get_index_type (ltype);
+          gsize cols = gsk_sl_type_get_length (ltype);
+          gsize c;
+          guint32 left_part_id, right_part_id, ids[cols];
+
+          for (c = 0; c < cols; c++)
+            {
+              left_part_id = gsk_spv_writer_composite_extract (writer,
+                                                               col_type,
+                                                               left_id,
+                                                               (guint32[1]) { c }, 1);
+              right_part_id = gsk_spv_writer_composite_extract (writer,
+                                                                col_type,
+                                                                right_id,
+                                                                (guint32[1]) { c }, 1);
+              ids[c] = gsk_spv_writer_f_div (writer,
+                                             col_type,
+                                             left_part_id,
+                                             right_part_id);
+            }
+
+          return gsk_spv_writer_composite_construct (writer,
+                                                     type,
+                                                     ids,
+                                                     cols);
+        }
+      else if (gsk_sl_type_is_scalar (rtype))
+        {
+          guint32 tmp_id;
+
+          /* matrix / scalar == matrix * (1 / scalar) */
+          tmp_id = gsk_spv_writer_f_div (writer,
+                                         rtype,
+                                         gsk_spv_writer_get_id_for_one (writer, gsk_sl_type_get_scalar_type (type)),
+                                         right_id);
+
+          return gsk_spv_writer_matrix_times_scalar (writer,
+                                                     type,
+                                                     left_id,
+                                                     tmp_id);
+        }
+      else
+        {
+          g_assert_not_reached ();
+          return 0;
+        }
+    }
+  else if (gsk_sl_type_is_matrix (rtype))
+    {
+      /* scalar / matrix. Multiplying the matrix by (1 / scalar) would compute
+       * matrix / scalar, i.e. the operands reversed. Instead, replicate the
+       * scalar into one column vector and divide it by every column. */
+      GskSlType *col_type = gsk_sl_type_get_index_type (rtype);
+      gsize cols = gsk_sl_type_get_length (rtype);
+      gsize c;
+      guint32 splat_id, right_part_id, ids[cols];
+
+      g_assert (gsk_sl_type_is_scalar (ltype));
+
+      splat_id = gsk_spv_writer_composite_construct (writer,
+                                                     col_type,
+                                                     (guint32[4]) { left_id, left_id, left_id, left_id },
+                                                     gsk_sl_type_get_length (col_type));
+      for (c = 0; c < cols; c++)
+        {
+          right_part_id = gsk_spv_writer_composite_extract (writer,
+                                                            col_type,
+                                                            right_id,
+                                                            (guint32[1]) { c }, 1);
+          ids[c] = gsk_spv_writer_f_div (writer,
+                                         col_type,
+                                         splat_id,
+                                         right_part_id);
+        }
+
+      return gsk_spv_writer_composite_construct (writer,
+                                                 type,
+                                                 ids,
+                                                 cols);
+    }
+  else
+    {
+      /* ltype and rtype are not matrices */
+
+      if (gsk_sl_type_is_scalar (ltype) && gsk_sl_type_is_vector (rtype))
+        {
+          left_id = gsk_spv_writer_composite_construct (writer,
+                                                        type,
+                                                        (guint32[4]) { left_id, left_id, left_id, left_id },
+                                                        gsk_sl_type_get_length (rtype));
+        }
+      else if (gsk_sl_type_is_scalar (rtype) && gsk_sl_type_is_vector (ltype))
+        {
+          right_id = gsk_spv_writer_composite_construct (writer,
+                                                         type,
+                                                         (guint32[4]) { right_id, right_id, right_id, right_id },
+                                                         gsk_sl_type_get_length (ltype));
+        }
+
+      /* ltype and rtype have the same number of components now */
+
+      switch (gsk_sl_type_get_scalar_type (type))
+        {
+        case GSK_SL_FLOAT:
+        case GSK_SL_DOUBLE:
+          return gsk_spv_writer_f_div (writer, type, left_id, right_id);
+
+        case GSK_SL_INT:
+          return gsk_spv_writer_s_div (writer, type, left_id, right_id);
+
+        case GSK_SL_UINT:
+          return gsk_spv_writer_u_div (writer, type, left_id, right_id);
+
+        case GSK_SL_VOID:
+        case GSK_SL_BOOL:
+        default:
+          g_assert_not_reached ();
+          return 0;
+        }
+    }
+}
+
+/* Operator table entry for "/": wires the shared arithmetic type check and
+ * the division constant-folding and SPIR-V vfuncs above into a GskSlBinary.
+ * gsk_sl_binary_get_for_token() returns it for both "/" and "/=". */
+static const GskSlBinary GSK_SL_BINARY_DIVISION = {
+  "/",
+  gsk_sl_arithmetic_check_type,
+  gsk_sl_division_get_constant,
+  gsk_sl_division_write_spv
+};
+
+/* UNIMPLEMENTED */
+
+/* Placeholder vfuncs for operators that are dispatched through GskSlBinary
+ * but have no constant folding or SPIR-V generation yet. Both abort through
+ * g_assert_not_reached(), so only the type check of such operators is usable.
+ * NOTE(review): gsk_sl_binary_get_for_token() hands these out for "+" and "-";
+ * confirm no caller folds or emits such expressions before they are
+ * implemented.
+ */
+
+static GskSlValue *
+gsk_sl_unimplemented_get_constant (GskSlType  *type,
+                                   GskSlValue *lvalue,
+                                   GskSlValue *rvalue)
+{
+  g_assert_not_reached ();
+
+  return NULL;
+}
+
+static guint32
+gsk_sl_unimplemented_write_spv (GskSpvWriter *writer,
+                                GskSlType    *type,
+                                GskSlType    *ltype,
+                                guint32       left_id,
+                                GskSlType    *rtype,
+                                guint32       right_id)
+{
+  g_assert_not_reached ();
+
+  return 0;
+}
+
+/* "+": shares the arithmetic type check; folding and codegen are placeholders. */
+static const GskSlBinary GSK_SL_BINARY_ADDITION = {
+  "+",
+  gsk_sl_arithmetic_check_type,
+  gsk_sl_unimplemented_get_constant,
+  gsk_sl_unimplemented_write_spv
+};
+
+/* "-": shares the arithmetic type check; folding and codegen are placeholders. */
+static const GskSlBinary GSK_SL_BINARY_SUBTRACTION = {
+  "-",
+  gsk_sl_arithmetic_check_type,
+  gsk_sl_unimplemented_get_constant,
+  gsk_sl_unimplemented_write_spv
+};
+
 /* API */
 
 const char *
@@ -460,10 +826,22 @@ gsk_sl_binary_get_for_token (GskSlTokenType token)
     case GSK_SL_TOKEN_MUL_ASSIGN:
       return &GSK_SL_BINARY_MULTIPLICATION;
 
+    case GSK_SL_TOKEN_SLASH:
     case GSK_SL_TOKEN_DIV_ASSIGN:
+      return &GSK_SL_BINARY_DIVISION;
+
+    case GSK_SL_TOKEN_PERCENT:
     case GSK_SL_TOKEN_MOD_ASSIGN:
+      return NULL;
+
+    case GSK_SL_TOKEN_PLUS:
     case GSK_SL_TOKEN_ADD_ASSIGN:
+      return &GSK_SL_BINARY_ADDITION;
+
+    case GSK_SL_TOKEN_DASH:
     case GSK_SL_TOKEN_SUB_ASSIGN:
+      return &GSK_SL_BINARY_SUBTRACTION;
+
     case GSK_SL_TOKEN_LEFT_ASSIGN:
     case GSK_SL_TOKEN_RIGHT_ASSIGN:
     case GSK_SL_TOKEN_AND_ASSIGN:
diff --git a/gsk/gskslexpression.c b/gsk/gskslexpression.c
index 04b4dcc..f73f78d 100644
--- a/gsk/gskslexpression.c
+++ b/gsk/gskslexpression.c
@@ -172,11 +172,11 @@ static const GskSlExpressionClass GSK_SL_EXPRESSION_ASSIGNMENT = {
   gsk_sl_expression_assignment_write_spv
 };
 
-/* MULTIPLICATION */
+/* BINARY */
 
-typedef struct _GskSlExpressionMultiplication GskSlExpressionMultiplication;
+typedef struct _GskSlExpressionBinary GskSlExpressionBinary;
 
-struct _GskSlExpressionMultiplication {
+struct _GskSlExpressionBinary {
   GskSlExpression parent;
 
   const GskSlBinary *binary;
@@ -186,529 +186,81 @@ struct _GskSlExpressionMultiplication {
 };
 
+/* Frees a binary expression: drops the references held on both operand
+ * expressions and on the result type, then releases the slice. */
 static void
-gsk_sl_expression_multiplication_free (GskSlExpression *expression)
+gsk_sl_expression_binary_free (GskSlExpression *expression)
 {
-  GskSlExpressionMultiplication *multiplication = (GskSlExpressionMultiplication *) expression;
+  GskSlExpressionBinary *binary = (GskSlExpressionBinary *) expression;
 
-  gsk_sl_expression_unref (multiplication->left);
-  gsk_sl_expression_unref (multiplication->right);
-  gsk_sl_type_unref (multiplication->type);
+  gsk_sl_expression_unref (binary->left);
+  gsk_sl_expression_unref (binary->right);
+  gsk_sl_type_unref (binary->type);
 
-  g_slice_free (GskSlExpressionMultiplication, multiplication);
+  g_slice_free (GskSlExpressionBinary, binary);
 }
 
+/* Prints the expression as "<left> <sign> <right>", taking the operator sign
+ * from the expression's GskSlBinary. No parentheses are emitted. */
 static void
-gsk_sl_expression_multiplication_print (const GskSlExpression *expression,
-                                        GskSlPrinter          *printer)
+gsk_sl_expression_binary_print (const GskSlExpression *expression,
+                                GskSlPrinter          *printer)
 {
-  GskSlExpressionMultiplication *multiplication = (GskSlExpressionMultiplication *) expression;
+  GskSlExpressionBinary *binary = (GskSlExpressionBinary *) expression;
 
-  gsk_sl_expression_print (multiplication->left, printer);
+  gsk_sl_expression_print (binary->left, printer);
   gsk_sl_printer_append (printer, " ");
-  gsk_sl_printer_append (printer, gsk_sl_binary_get_sign (multiplication->binary));
+  gsk_sl_printer_append (printer, gsk_sl_binary_get_sign (binary->binary));
   gsk_sl_printer_append (printer, " ");
-  gsk_sl_expression_print (multiplication->right, printer);
+  gsk_sl_expression_print (binary->right, printer);
 }
 
+/* Returns the result type stored in the expression; no reference is added. */
 static GskSlType *
-gsk_sl_expression_multiplication_get_return_type (const GskSlExpression *expression)
+gsk_sl_expression_binary_get_return_type (const GskSlExpression *expression)
 {
-  GskSlExpressionMultiplication *multiplication = (GskSlExpressionMultiplication *) expression;
-
-  return multiplication->type;
-}
+  GskSlExpressionBinary *binary = (GskSlExpressionBinary *) expression;
 
-#define GSK_SL_OPERATION_FUNC_SCALAR(func,type,...) \
-static void \
-func (gpointer value, gpointer scalar) \
-{ \
-  type x = *(type *) value; \
-  type y = *(type *) scalar; \
-  __VA_ARGS__ \
-  *(type *) value = x; \
+  return binary->type;
 }
-GSK_SL_OPERATION_FUNC_SCALAR(gsk_sl_expression_division_int, gint32, x = y == 0 ? G_MAXINT32 : x / y;)
-GSK_SL_OPERATION_FUNC_SCALAR(gsk_sl_expression_division_uint, guint32, x = y == 0 ? G_MAXUINT32 : x / y;)
-GSK_SL_OPERATION_FUNC_SCALAR(gsk_sl_expression_division_float, float, x /= y;)
-GSK_SL_OPERATION_FUNC_SCALAR(gsk_sl_expression_division_double, double, x /= y;)
-GSK_SL_OPERATION_FUNC_SCALAR(gsk_sl_expression_division_int_inv, gint32, x = x == 0 ? G_MAXINT32 : y / x;)
-GSK_SL_OPERATION_FUNC_SCALAR(gsk_sl_expression_division_uint_inv, guint32, x = x == 0 ? G_MAXUINT32 : y / x;)
-GSK_SL_OPERATION_FUNC_SCALAR(gsk_sl_expression_division_float_inv, float, x = y / x;)
-GSK_SL_OPERATION_FUNC_SCALAR(gsk_sl_expression_division_double_inv, double, x = y / x;)
-static void (* div_funcs[]) (gpointer, gpointer) = {
-  [GSK_SL_INT] = gsk_sl_expression_division_int,
-  [GSK_SL_UINT] = gsk_sl_expression_division_uint,
-  [GSK_SL_FLOAT] = gsk_sl_expression_division_float,
-  [GSK_SL_DOUBLE] = gsk_sl_expression_division_double,
-};
-static void (* div_inv_funcs[]) (gpointer, gpointer) = {
-  [GSK_SL_INT] = gsk_sl_expression_division_int_inv,
-  [GSK_SL_UINT] = gsk_sl_expression_division_uint_inv,
-  [GSK_SL_FLOAT] = gsk_sl_expression_division_float_inv,
-  [GSK_SL_DOUBLE] = gsk_sl_expression_division_double_inv,
-};
-
 
+/* Constant-folds the expression when both operands are constants.
+ * Returns NULL if either operand is not constant, freeing an already
+ * computed left value. Otherwise both values are handed to
+ * gsk_sl_binary_get_constant() together with the result type; the division
+ * implementation in gskslbinary.c consumes them (one is freed, the other
+ * returned). */
 static GskSlValue *
-gsk_sl_expression_multiplication_get_constant (const GskSlExpression *expression)
+gsk_sl_expression_binary_get_constant (const GskSlExpression *expression)
 {
-  const GskSlExpressionMultiplication *multiplication = (const GskSlExpressionMultiplication *) expression;
+  const GskSlExpressionBinary *binary = (const GskSlExpressionBinary *) expression;
   GskSlValue *lvalue, *rvalue;
 
-  lvalue = gsk_sl_expression_get_constant (multiplication->left);
+  lvalue = gsk_sl_expression_get_constant (binary->left);
   if (lvalue == NULL)
     return NULL;
-  rvalue = gsk_sl_expression_get_constant (multiplication->right);
+  rvalue = gsk_sl_expression_get_constant (binary->right);
   if (rvalue == NULL)
     {
       gsk_sl_value_free (lvalue);
       return NULL;
     }
 
-  return gsk_sl_binary_get_constant (multiplication->binary,
-                                     multiplication->type,
+  return gsk_sl_binary_get_constant (binary->binary,
+                                     binary->type,
                                      lvalue,
                                      rvalue);
 }
 
 static guint32
-gsk_sl_expression_multiplication_write_spv (const GskSlExpression *expression,
+gsk_sl_expression_binary_write_spv (const GskSlExpression *expression,
                                             GskSpvWriter          *writer)
 {
-  const GskSlExpressionMultiplication *multiplication = (const GskSlExpressionMultiplication *) expression;
+  const GskSlExpressionBinary *binary = (const GskSlExpressionBinary *) expression;
 
-  return gsk_sl_binary_write_spv (multiplication->binary,
+  return gsk_sl_binary_write_spv (binary->binary,
                                   writer,
-                                  multiplication->type,
-                                  gsk_sl_expression_get_return_type (multiplication->left),
-                                  gsk_sl_expression_write_spv (multiplication->left, writer),
-                                  gsk_sl_expression_get_return_type (multiplication->right),
-                                  gsk_sl_expression_write_spv (multiplication->right, writer));
-}
-
-static const GskSlExpressionClass GSK_SL_EXPRESSION_MULTIPLICATION = {
-  gsk_sl_expression_multiplication_free,
-  gsk_sl_expression_multiplication_print,
-  gsk_sl_expression_multiplication_get_return_type,
-  gsk_sl_expression_multiplication_get_constant,
-  gsk_sl_expression_multiplication_write_spv
-};
-
-/* ARITHMETIC */
-
-static GskSlType *
-gsk_sl_expression_arithmetic_type_check (GskSlPreprocessor *preproc,
-                                         GskSlType         *ltype,
-                                         GskSlType         *rtype)
-{
-  GskSlScalarType scalar;
-
-  if (gsk_sl_scalar_type_can_convert (gsk_sl_type_get_scalar_type (ltype),
-                                      gsk_sl_type_get_scalar_type (rtype)))
-    scalar = gsk_sl_type_get_scalar_type (ltype);
-  else if (gsk_sl_scalar_type_can_convert (gsk_sl_type_get_scalar_type (rtype),
-                                           gsk_sl_type_get_scalar_type (ltype)))
-    scalar = gsk_sl_type_get_scalar_type (rtype);
-  else
-    {
-      gsk_sl_preprocessor_error (preproc, TYPE_MISMATCH,
-                                 "Operand types %s and %s do not share compatible scalar types.",
-                                 gsk_sl_type_get_name (ltype), gsk_sl_type_get_name (rtype));
-      return NULL;
-    }
-
-  if (gsk_sl_type_is_matrix (ltype))
-    {
-      if (gsk_sl_type_is_matrix (rtype))
-        {
-          if (gsk_sl_type_can_convert (ltype, rtype))
-            {
-              return ltype;
-            }
-          else if (gsk_sl_type_can_convert (rtype, ltype))
-            {
-              return rtype;
-            }
-          else
-            {
-              gsk_sl_preprocessor_error (preproc, TYPE_MISMATCH,
-                                         "Matrix types %s and %s have different size.",
-                                         gsk_sl_type_get_name (ltype), gsk_sl_type_get_name (rtype));
-              return NULL;
-            }
-        }
-      else if (gsk_sl_type_is_vector (rtype))
-        {
-          gsk_sl_preprocessor_error (preproc, TYPE_MISMATCH,
-                                     "Cannot perform arithmetic arithmetic between matrix and vector.");
-          return NULL;
-        }
-      else if (gsk_sl_type_is_scalar (rtype))
-        {
-          return gsk_sl_type_get_matrix (scalar,
-                                         gsk_sl_type_get_length (ltype),
-                                         gsk_sl_type_get_length (gsk_sl_type_get_index_type (ltype)));
-        }
-      else
-        {
-          gsk_sl_preprocessor_error (preproc, TYPE_MISMATCH,
-                                     "Right operand is incompatible type for arithemtic arithmetic.");
-          return NULL;
-        }
-    }
-  else if (gsk_sl_type_is_vector (ltype))
-    {
-      if (gsk_sl_type_is_matrix (rtype))
-        {
-          gsk_sl_preprocessor_error (preproc, TYPE_MISMATCH, "Cannot perform arithmetic arithmetic between 
vector and matrix.");
-          return NULL;
-        }
-      else if (gsk_sl_type_is_vector (rtype))
-        {
-          if (gsk_sl_type_get_length (ltype) != gsk_sl_type_get_length (rtype))
-            {
-              gsk_sl_preprocessor_error (preproc, TYPE_MISMATCH,
-                                         "Vector operands %s and %s to arithmetic arithmetic have different 
length.",
-                                         gsk_sl_type_get_name (ltype), gsk_sl_type_get_name (rtype));
-              return NULL;
-            }
-          return gsk_sl_type_get_vector (scalar, gsk_sl_type_get_length (ltype));
-        }
-      else if (gsk_sl_type_is_scalar (rtype))
-        {
-          return gsk_sl_type_get_vector (scalar,
-                                         gsk_sl_type_get_length (ltype));
-        }
-      else
-        {
-          gsk_sl_preprocessor_error (preproc, TYPE_MISMATCH,
-                                     "Right operand is incompatible type for arithemtic arithmetic.");
-          return NULL;
-        }
-    }
-  else if (gsk_sl_type_is_scalar (ltype))
-    {
-      if (gsk_sl_type_is_matrix (rtype))
-        {
-          return gsk_sl_type_get_matrix (scalar,
-                                         gsk_sl_type_get_length (rtype),
-                                         gsk_sl_type_get_length (gsk_sl_type_get_index_type (rtype)));
-        }
-      else if (gsk_sl_type_is_vector (rtype))
-        {
-          return gsk_sl_type_get_vector (scalar,
-                                         gsk_sl_type_get_length (rtype));
-        }
-      else if (gsk_sl_type_is_scalar (rtype))
-        {
-          return gsk_sl_type_get_scalar (scalar);
-        }
-      else
-        {
-          gsk_sl_preprocessor_error (preproc, TYPE_MISMATCH, "Right operand is incompatible type for 
arithemtic arithmetic.");
-          return NULL;
-        }
-    }
-  else
-    {
-      gsk_sl_preprocessor_error (preproc, TYPE_MISMATCH, "Left operand is incompatible type for arithemtic 
arithmetic.");
-      return NULL;
-    }
-}
-
-typedef struct _GskSlExpressionArithmetic GskSlExpressionArithmetic;
-
-struct _GskSlExpressionArithmetic {
-  GskSlExpression parent;
-
-  GskSlType *type;
-  GskSlExpression *left;
-  GskSlExpression *right;
-};
-
-static void
-gsk_sl_expression_arithmetic_free (GskSlExpression *expression)
-{
-  GskSlExpressionArithmetic *arithmetic = (GskSlExpressionArithmetic *) expression;
-
-  gsk_sl_type_unref (arithmetic->type);
-  gsk_sl_expression_unref (arithmetic->left);
-  gsk_sl_expression_unref (arithmetic->right);
-
-  g_slice_free (GskSlExpressionArithmetic, arithmetic);
-}
-
-static void
-gsk_sl_expression_division_print (const GskSlExpression *expression,
-                                  GskSlPrinter          *printer)
-{
-  GskSlExpressionArithmetic *arithmetic = (GskSlExpressionArithmetic *) expression;
-
-  gsk_sl_expression_print (arithmetic->left, printer);
-  gsk_sl_printer_append (printer, " / ");
-  gsk_sl_expression_print (arithmetic->right, printer);
-}
-
-static void
-gsk_sl_expression_addition_print (const GskSlExpression *expression,
-                                  GskSlPrinter          *printer)
-{
-  GskSlExpressionArithmetic *arithmetic = (GskSlExpressionArithmetic *) expression;
-
-  gsk_sl_expression_print (arithmetic->left, printer);
-  gsk_sl_printer_append (printer, " + ");
-  gsk_sl_expression_print (arithmetic->right, printer);
-}
-
-static void
-gsk_sl_expression_subtraction_print (const GskSlExpression *expression,
-                                     GskSlPrinter          *printer)
-{
-  GskSlExpressionArithmetic *arithmetic = (GskSlExpressionArithmetic *) expression;
-
-  gsk_sl_expression_print (arithmetic->left, printer);
-  gsk_sl_printer_append (printer, " - ");
-  gsk_sl_expression_print (arithmetic->right, printer);
-}
-
-static GskSlType *
-gsk_sl_expression_arithmetic_get_return_type (const GskSlExpression *expression)
-{
-  GskSlExpressionArithmetic *arithmetic = (GskSlExpressionArithmetic *) expression;
-
-  return arithmetic->type;
-}
-
-static GskSlValue *
-gsk_sl_expression_division_get_constant (const GskSlExpression *expression)
-{
-  const GskSlExpressionArithmetic *arithmetic = (const GskSlExpressionArithmetic *) expression;
-  GskSlValue *result, *lvalue, *rvalue;
-  GskSlType *ltype, *rtype;
-  GskSlScalarType scalar;
-  gsize ln, rn;
-
-  scalar = gsk_sl_type_get_scalar_type (arithmetic->type);
-  lvalue = gsk_sl_expression_get_constant (arithmetic->left);
-  if (lvalue == NULL)
-    return NULL;
-  rvalue = gsk_sl_expression_get_constant (arithmetic->right);
-  if (rvalue == NULL)
-    {
-      gsk_sl_value_free (lvalue);
-      return NULL;
-    }
-  lvalue = gsk_sl_value_convert_components (lvalue, scalar);
-  rvalue = gsk_sl_value_convert_components (rvalue, scalar);
-  ltype = gsk_sl_value_get_type (lvalue);
-  rtype = gsk_sl_value_get_type (rvalue);
-
-  ln = gsk_sl_type_get_n_components (ltype);
-  rn = gsk_sl_type_get_n_components (rtype);
-  if (ln == 1)
-    {
-      gsk_sl_value_componentwise (rvalue, div_inv_funcs[scalar], gsk_sl_value_get_data (lvalue));
-      gsk_sl_value_free (lvalue);
-      result = rvalue;
-    }
-  else if (rn == 1)
-    {
-      gsk_sl_value_componentwise (lvalue, div_funcs[scalar], gsk_sl_value_get_data (rvalue));
-      gsk_sl_value_free (rvalue);
-      result = lvalue;
-    }
-  else
-    {
-      guchar *ldata, *rdata;
-      gsize i, stride;
-
-      stride = gsk_sl_scalar_type_get_size (scalar);
-      ldata = gsk_sl_value_get_data (lvalue);
-      rdata = gsk_sl_value_get_data (rvalue);
-      for (i = 0; i < ln; i++)
-        {
-          div_funcs[scalar] (ldata + i * stride, rdata + i * stride);
-        }
-      gsk_sl_value_free (rvalue);
-      result = lvalue;
-    }
-
-  return result;
-}
-
-static guint32
-gsk_sl_expression_division_write_spv (const GskSlExpression *expression,
-                                      GskSpvWriter          *writer)
-{
-  const GskSlExpressionArithmetic *arithmetic = (const GskSlExpressionArithmetic *) expression;
-  GskSlType *ltype, *rtype;
-  guint32 left_id, right_id;
-
-  ltype = gsk_sl_expression_get_return_type (arithmetic->left);
-  rtype = gsk_sl_expression_get_return_type (arithmetic->right);
-
-  left_id = gsk_sl_expression_write_spv (arithmetic->left, writer);
-  if (gsk_sl_type_get_scalar_type (ltype) != gsk_sl_type_get_scalar_type (arithmetic->type))
-    {
-      GskSlType *new_type = gsk_sl_type_get_matching (ltype, gsk_sl_type_get_scalar_type (arithmetic->type));
-      left_id = gsk_spv_writer_convert (writer, left_id, ltype, new_type);
-      ltype = new_type;
-    }
-  right_id = gsk_sl_expression_write_spv (arithmetic->right, writer);
-  if (gsk_sl_type_get_scalar_type (rtype) != gsk_sl_type_get_scalar_type (arithmetic->type))
-    {
-      GskSlType *new_type = gsk_sl_type_get_matching (rtype, gsk_sl_type_get_scalar_type (arithmetic->type));
-      right_id = gsk_spv_writer_convert (writer, right_id, rtype, new_type);
-      rtype = new_type;
-    }
-
-  if (gsk_sl_type_is_matrix (ltype))
-    {
-      if (gsk_sl_type_is_matrix (rtype))
-        {
-          GskSlType *col_type = gsk_sl_type_get_index_type (ltype);
-          gsize cols = gsk_sl_type_get_length (ltype);
-          gsize c;
-          guint32 left_part_id, right_part_id, ids[cols];
-
-          for (c = 0; c < cols; c++)
-            {
-              left_part_id = gsk_spv_writer_composite_extract (writer, 
-                                                               col_type,
-                                                               left_id,
-                                                               (guint32[1]) { 1 }, c);
-              right_part_id = gsk_spv_writer_composite_extract (writer, 
-                                                                col_type,
-                                                                right_id,
-                                                                (guint32[1]) { 1 }, c);
-              ids[c] = gsk_spv_writer_f_div (writer,
-                                             col_type,
-                                             left_part_id,
-                                             right_part_id);
-            }
-
-          return gsk_spv_writer_composite_construct (writer, 
-                                                     arithmetic->type,
-                                                     ids,
-                                                     cols);
-        }
-      else if (gsk_sl_type_is_scalar (rtype))
-        {
-          guint32 tmp_id;
-
-          tmp_id = gsk_spv_writer_f_div (writer,
-                                         rtype,
-                                         gsk_spv_writer_get_id_for_one (writer, gsk_sl_type_get_scalar_type 
(arithmetic->type)),
-                                         right_id);
-
-          return gsk_spv_writer_matrix_times_scalar (writer,
-                                                     arithmetic->type,
-                                                     left_id,
-                                                     tmp_id);
-        }
-      else
-        {
-          g_assert_not_reached ();
-          return 0;
-        }
-    }
-  else if (gsk_sl_type_is_matrix (rtype))
-    {
-      guint32 tmp_id;
-
-      tmp_id = gsk_spv_writer_f_div (writer,
-                                     ltype,
-                                     gsk_spv_writer_get_id_for_one (writer, gsk_sl_type_get_scalar_type 
(arithmetic->type)),
-                                     left_id);
-      return gsk_spv_writer_matrix_times_scalar (writer,
-                                                 arithmetic->type,
-                                                 right_id,
-                                                 tmp_id);
-    }
-  else
-    {
-      /* ltype and rtype are not matrices */
-
-      if (gsk_sl_type_is_scalar (ltype) && gsk_sl_type_is_vector (rtype))
-        {
-           guint32 tmp_id = gsk_spv_writer_composite_construct (writer,
-                                                                arithmetic->type,
-                                                                (guint32[4]) { left_id, left_id, left_id, 
left_id },
-                                                                gsk_sl_type_get_length (rtype));
-           left_id = tmp_id;
-        }
-      else if (gsk_sl_type_is_scalar (rtype) && gsk_sl_type_is_vector (ltype))
-        {
-           guint32 tmp_id = gsk_spv_writer_composite_construct (writer,
-                                                                arithmetic->type,
-                                                                (guint32[4]) { right_id, right_id, right_id, 
right_id },
-                                                                gsk_sl_type_get_length (ltype));
-           right_id = tmp_id;
-        }
-
-      /* ltype and rtype have the same number of components now */
-
-      switch (gsk_sl_type_get_scalar_type (arithmetic->type))
-        {
-        case GSK_SL_FLOAT:
-        case GSK_SL_DOUBLE:
-          return gsk_spv_writer_f_div (writer, arithmetic->type, left_id, right_id);
-
-        case GSK_SL_INT:
-          return gsk_spv_writer_s_div (writer, arithmetic->type, left_id, right_id);
-
-        case GSK_SL_UINT:
-          return gsk_spv_writer_u_div (writer, arithmetic->type, left_id, right_id);
-
-        case GSK_SL_VOID:
-        case GSK_SL_BOOL:
-        default:
-          g_assert_not_reached ();
-          return 0;
-        }
-    }
-}
-
-static GskSlValue *
-gsk_sl_expression_arithmetic_get_constant (const GskSlExpression *expression)
-{
-  //GskSlExpressionArithmetic *arithmetic = (const GskSlExpressionArithmetic *) expression;
-
-  /* FIXME: These need constant evaluations */
-  return NULL;
-}
-
-static guint32
-gsk_sl_expression_arithmetic_write_spv (const GskSlExpression *expression,
-                                        GskSpvWriter          *writer)
-{
-  g_assert_not_reached ();
-
-  return 0;
-}
-
-static const GskSlExpressionClass GSK_SL_EXPRESSION_DIVISION = {
-  gsk_sl_expression_arithmetic_free,
-  gsk_sl_expression_division_print,
-  gsk_sl_expression_arithmetic_get_return_type,
-  gsk_sl_expression_division_get_constant,
-  gsk_sl_expression_division_write_spv
-};
-
-static const GskSlExpressionClass GSK_SL_EXPRESSION_ADDITION = {
-  gsk_sl_expression_arithmetic_free,
-  gsk_sl_expression_addition_print,
-  gsk_sl_expression_arithmetic_get_return_type,
-  gsk_sl_expression_arithmetic_get_constant,
-  gsk_sl_expression_arithmetic_write_spv
-};
-
-static const GskSlExpressionClass GSK_SL_EXPRESSION_SUBTRACTION = {
-  gsk_sl_expression_arithmetic_free,
-  gsk_sl_expression_subtraction_print,
-  gsk_sl_expression_arithmetic_get_return_type,
-  gsk_sl_expression_arithmetic_get_constant,
-  gsk_sl_expression_arithmetic_write_spv
+                                  binary->type,
+                                  gsk_sl_expression_get_return_type (binary->left),
+                                  gsk_sl_expression_write_spv (binary->left, writer),
+                                  gsk_sl_expression_get_return_type (binary->right),
+                                  gsk_sl_expression_write_spv (binary->right, writer));
+}
+
+static const GskSlExpressionClass GSK_SL_EXPRESSION_BINARY = {
+  gsk_sl_expression_binary_free,
+  gsk_sl_expression_binary_print,
+  gsk_sl_expression_binary_get_return_type,
+  gsk_sl_expression_binary_get_constant,
+  gsk_sl_expression_binary_write_spv
 };
 
 /* OPERATION */
@@ -2578,46 +2130,25 @@ gsk_sl_expression_parse_multiplicative (GskSlScope        *scope,
 
       gsk_sl_preprocessor_consume (stream, NULL);
       right = gsk_sl_expression_parse_unary (scope, stream);
-      if (op == MUL)
+      if (op == MUL || op == DIV)
         {
           const GskSlBinary *binary;
           GskSlType *result_type;
 
-          binary = gsk_sl_binary_get_for_token (GSK_SL_TOKEN_STAR);
+          binary = gsk_sl_binary_get_for_token (op == MUL ? GSK_SL_TOKEN_STAR : GSK_SL_TOKEN_SLASH);
           result_type = gsk_sl_binary_check_type (binary,
                                                   stream,
                                                   gsk_sl_expression_get_return_type (expression),
                                                   gsk_sl_expression_get_return_type (right));
           if (result_type)
             {
-              GskSlExpressionMultiplication *multiplication;
-              multiplication = gsk_sl_expression_new (GskSlExpressionMultiplication, 
&GSK_SL_EXPRESSION_MULTIPLICATION);
-              multiplication->binary = binary;
-              multiplication->type = gsk_sl_type_ref (result_type);
-              multiplication->left = expression;
-              multiplication->right = right;
-              expression = (GskSlExpression *) multiplication;
-            }
-          else
-            {
-              gsk_sl_expression_unref ((GskSlExpression *) right);
-            }
-        }
-      else if (op == DIV)
-        {
-          GskSlType *result_type;
-
-          result_type = gsk_sl_expression_arithmetic_type_check (stream,
-                                                                 gsk_sl_expression_get_return_type 
(expression),
-                                                                 gsk_sl_expression_get_return_type (right));
-          if (result_type)
-            {
-              GskSlExpressionArithmetic *division;
-              division = gsk_sl_expression_new (GskSlExpressionArithmetic, &GSK_SL_EXPRESSION_DIVISION);
-              division->type = gsk_sl_type_ref (result_type);
-              division->left = expression;
-              division->right = right;
-              expression = (GskSlExpression *) division;
+              GskSlExpressionBinary *binary_expr;
+              binary_expr = gsk_sl_expression_new (GskSlExpressionBinary, &GSK_SL_EXPRESSION_BINARY);
+              binary_expr->binary = binary;
+              binary_expr->type = gsk_sl_type_ref (result_type);
+              binary_expr->left = expression;
+              binary_expr->right = right;
+              expression = (GskSlExpression *) binary_expr;
             }
           else
             {
@@ -2652,36 +2183,37 @@ gsk_sl_expression_parse_additive (GskSlScope        *scope,
                                   GskSlPreprocessor *stream)
 {
   const GskSlToken *token;
+  const GskSlBinary *binary;
   GskSlExpression *expression, *right;
   GskSlType *result_type;
-  enum { ADD, SUB } op;
 
   expression = gsk_sl_expression_parse_multiplicative (scope, stream);
 
   while (TRUE)
     {
       token = gsk_sl_preprocessor_get (stream);
-      if (gsk_sl_token_is (token, GSK_SL_TOKEN_PLUS))
-        op = ADD;
-      else if (gsk_sl_token_is (token, GSK_SL_TOKEN_DASH))
-        op = SUB;
-      else
+      if (!gsk_sl_token_is (token, GSK_SL_TOKEN_PLUS) &&
+          !gsk_sl_token_is (token, GSK_SL_TOKEN_DASH))
         return expression;
 
+      binary = gsk_sl_binary_get_for_token (token->type);
       gsk_sl_preprocessor_consume (stream, NULL);
-      right = gsk_sl_expression_parse_additive (scope, stream);
-      result_type = gsk_sl_expression_arithmetic_type_check (stream,
-                                                             gsk_sl_expression_get_return_type (expression),
-                                                             gsk_sl_expression_get_return_type (right));
+
+      right = gsk_sl_expression_parse_multiplicative (scope, stream);
+
+      result_type = gsk_sl_binary_check_type (binary,
+                                              stream,
+                                              gsk_sl_expression_get_return_type (expression),
+                                              gsk_sl_expression_get_return_type (right));
       if (result_type)
         {
-          GskSlExpressionArithmetic *arithmetic;
-
-          arithmetic = gsk_sl_expression_new (GskSlExpressionArithmetic, op == ADD ? 
&GSK_SL_EXPRESSION_ADDITION : &GSK_SL_EXPRESSION_SUBTRACTION);
-          arithmetic->type = gsk_sl_type_ref (result_type);
-          arithmetic->left = expression;
-          arithmetic->right = right;
-          expression = (GskSlExpression *) arithmetic;
+          GskSlExpressionBinary *binary_expr;
+          binary_expr = gsk_sl_expression_new (GskSlExpressionBinary, &GSK_SL_EXPRESSION_BINARY);
+          binary_expr->binary = binary;
+          binary_expr->type = gsk_sl_type_ref (result_type);
+          binary_expr->left = expression;
+          binary_expr->right = right;
+          expression = (GskSlExpression *) binary_expr;
         }
       else
         {



[Date Prev][Date Next]   [Thread Prev][Thread Next]   [Thread Index] [Date Index] [Author Index]