Commit 8286174

mo khan <mo@mokhan.ca>
2026-01-09 23:06:26
feat: format UNNEST and COALESCE statement
1 parent 3534005
Changed files (3)
src
tests
fixtures
src/formatter.rs
@@ -282,10 +282,110 @@ fn format_table_factor(table: &TableFactor, indent_level: usize) -> String {
             }
             result
         }
+        TableFactor::UNNEST {
+            array_exprs,
+            alias,
+            with_offset,
+            with_offset_alias,
+            ..
+        } => {
+            let has_complex = array_exprs.iter().any(contains_complex_expr);
+            let mut result = String::from("UNNEST(");
+
+            if has_complex {
+                let inner_indent = indent_level + 2;
+                result.push('\n');
+                for (i, expr) in array_exprs.iter().enumerate() {
+                    if i > 0 {
+                        result.push_str(",\n");
+                    }
+                    result.push_str(&" ".repeat(inner_indent));
+                    result.push_str(&format_expr(expr, inner_indent));
+                }
+                result.push('\n');
+                result.push_str(&" ".repeat(indent_level));
+            } else {
+                let exprs: Vec<String> = array_exprs
+                    .iter()
+                    .map(|e| format_expr(e, indent_level))
+                    .collect();
+                result.push_str(&exprs.join(", "));
+            }
+            result.push(')');
+
+            if let Some(alias) = alias {
+                result.push_str(&format!(" AS {}", alias.name.value));
+                if !alias.columns.is_empty() {
+                    let cols: Vec<String> =
+                        alias.columns.iter().map(|c| c.name.value.clone()).collect();
+                    result.push_str(&format!(" ({})", cols.join(", ")));
+                }
+            }
+
+            if *with_offset {
+                result.push_str(" WITH OFFSET");
+                if let Some(offset_alias) = with_offset_alias {
+                    result.push_str(&format!(" AS {}", offset_alias.value));
+                }
+            }
+
+            result
+        }
+        TableFactor::Function {
+            name, args, alias, ..
+        } => {
+            let has_complex = args.iter().any(|arg| match arg {
+                FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => contains_complex_expr(e),
+                FunctionArg::Named { arg: FunctionArgExpr::Expr(e), .. } => contains_complex_expr(e),
+                _ => false,
+            });
+
+            let mut result = name.to_string().to_uppercase();
+            result.push('(');
+
+            if has_complex {
+                let inner_indent = indent_level + 2;
+                result.push('\n');
+                for (i, arg) in args.iter().enumerate() {
+                    if i > 0 {
+                        result.push_str(",\n");
+                    }
+                    result.push_str(&" ".repeat(inner_indent));
+                    result.push_str(&format_table_function_arg(arg, inner_indent));
+                }
+                result.push('\n');
+                result.push_str(&" ".repeat(indent_level));
+            } else {
+                let formatted_args: Vec<String> = args
+                    .iter()
+                    .map(|a| format_table_function_arg(a, indent_level))
+                    .collect();
+                result.push_str(&formatted_args.join(", "));
+            }
+            result.push(')');
+
+            if let Some(alias) = alias {
+                result.push_str(&format!(" AS {}", alias.name.value));
+            }
+
+            result
+        }
         _ => table.to_string(),
     }
 }
 
+fn format_table_function_arg(arg: &FunctionArg, indent_level: usize) -> String {
+    match arg {
+        FunctionArg::Named { name, arg, .. } => {
+            format!("{} => {}", name.value, format_function_arg_expr(arg, indent_level))
+        }
+        FunctionArg::Unnamed(arg) => format_function_arg_expr(arg, indent_level),
+        FunctionArg::ExprNamed { name, arg, .. } => {
+            format!("{} => {}", format_expr(name, indent_level), format_function_arg_expr(arg, indent_level))
+        }
+    }
+}
+
 fn format_join(join: &Join, indent_level: usize) -> String {
     let mut result = String::new();
 
@@ -485,48 +585,61 @@ fn format_function(function: &Function, indent_level: usize) -> String {
     let mut result = function.name.to_string().to_uppercase();
     result.push('(');
 
-    if let FunctionArguments::List(list) = &function.args {
-        if let Some(DuplicateTreatment::Distinct) = &list.duplicate_treatment {
-            result.push_str("DISTINCT ");
-        }
+    let has_distinct = if let FunctionArguments::List(list) = &function.args {
+        matches!(list.duplicate_treatment, Some(DuplicateTreatment::Distinct))
+    } else {
+        false
+    };
+
+    if has_distinct {
+        result.push_str("DISTINCT ");
     }
 
-    let args: Vec<String> = match &function.args {
-        FunctionArguments::None => vec![],
+    match &function.args {
+        FunctionArguments::None => {}
         FunctionArguments::Subquery(subquery) => {
             let inner_indent = indent_level + 2;
-            vec![format!(
-                "\n{}\n{}",
-                format_query(subquery, inner_indent),
-                " ".repeat(indent_level)
-            )]
+            result.push('\n');
+            result.push_str(&format_query(subquery, inner_indent));
+            result.push('\n');
+            result.push_str(&" ".repeat(indent_level));
         }
-        FunctionArguments::List(list) => list
-            .args
-            .iter()
-            .map(|arg| match arg {
-                FunctionArg::Named { name, arg, .. } => {
-                    format!(
-                        "{} => {}",
-                        name.value,
-                        format_function_arg_expr(arg, indent_level)
-                    )
+        FunctionArguments::List(list) => {
+            let has_complex = list.args.iter().any(|arg| match arg {
+                FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => contains_complex_expr(e),
+                FunctionArg::Named { arg: FunctionArgExpr::Expr(e), .. } => {
+                    contains_complex_expr(e)
                 }
-                FunctionArg::Unnamed(arg) => format_function_arg_expr(arg, indent_level),
-                FunctionArg::ExprNamed { name, arg, .. } => {
-                    format!(
-                        "{} => {}",
-                        format_expr(name, indent_level),
-                        format_function_arg_expr(arg, indent_level)
-                    )
+                FunctionArg::ExprNamed { arg: FunctionArgExpr::Expr(e), .. } => {
+                    contains_complex_expr(e)
                 }
-            })
-            .collect(),
+                _ => false,
+            });
+
+            if has_complex {
+                let inner_indent = indent_level + 2;
+                result.push('\n');
+                for (i, arg) in list.args.iter().enumerate() {
+                    if i > 0 {
+                        result.push_str(",\n");
+                    }
+                    result.push_str(&" ".repeat(inner_indent));
+                    result.push_str(&format_table_function_arg(arg, inner_indent));
+                }
+                result.push('\n');
+                result.push_str(&" ".repeat(indent_level));
+            } else {
+                let args: Vec<String> = list
+                    .args
+                    .iter()
+                    .map(|arg| format_table_function_arg(arg, indent_level))
+                    .collect();
+                result.push_str(&args.join(", "));
+            }
+        }
     };
 
-    result.push_str(&args.join(", "));
     result.push(')');
-
     result
 }
 
@@ -641,6 +754,34 @@ fn is_simple_expr(expr: &Expr) -> bool {
     }
 }
 
+fn contains_complex_expr(expr: &Expr) -> bool {
+    match expr {
+        Expr::Subquery(_) => true,
+        Expr::InSubquery { .. } => true,
+        Expr::Function(f) => {
+            if let FunctionArguments::List(list) = &f.args {
+                list.args.iter().any(|arg| match arg {
+                    FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => contains_complex_expr(e),
+                    FunctionArg::Named { arg: FunctionArgExpr::Expr(e), .. } => {
+                        contains_complex_expr(e)
+                    }
+                    FunctionArg::ExprNamed { arg: FunctionArgExpr::Expr(e), .. } => {
+                        contains_complex_expr(e)
+                    }
+                    _ => false,
+                })
+            } else {
+                matches!(f.args, FunctionArguments::Subquery(_))
+            }
+        }
+        Expr::Nested(inner) => contains_complex_expr(inner),
+        Expr::BinaryOp { left, right, .. } => {
+            contains_complex_expr(left) || contains_complex_expr(right)
+        }
+        _ => false,
+    }
+}
+
 fn format_where_expr(expr: &Expr, indent_level: usize) -> String {
     match expr {
         Expr::BinaryOp { left, op, right } if matches!(op, BinaryOperator::And) => {
tests/fixtures/unnest_coalesce/input.sql
@@ -0,0 +1,1 @@
+SELECT id FROM UNNEST(COALESCE((SELECT ids FROM cached), (SELECT ids FROM fallback))) AS t(id)
\ No newline at end of file
tests/fixtures/unnest_coalesce/output.sql
@@ -0,0 +1,16 @@
+SELECT id
+FROM
+  UNNEST(
+    COALESCE(
+      (
+        SELECT ids
+        FROM
+          cached
+      ),
+      (
+        SELECT ids
+        FROM
+          fallback
+      )
+    )
+  ) AS t (id)
\ No newline at end of file