-- | Pair up the elements of two arrays of the same length.
def zip [n] 'a 'b (as: [n]a) (bs: [n]b): [n](a,b) =
  map (\i -> (as[i], bs[i])) (iota n)

-- | The integers from 0 (inclusive) to `n` (exclusive).
def iota (n: i64) = 0..1..<n

-- | An array holding `n` copies of `x`.
--
-- NOTE(review): this definition was reconstructed.  The text between
-- the range in `iota` and the trailing `x) (iota n)` was lost in the
-- copy under review; `filter` below calls `replicate` with exactly
-- this argument order.  Confirm against the upstream file.
def replicate 't (n: i64) (x: t): *[n]t =
  map (\_ -> x) (iota n)

-- | The elements of `xs` followed by the elements of `ys`.
def concat 't (xs: []t) (ys: []t): *[]t =
  map (\i -> if i < length xs then xs[i] else ys[i - length xs])
      (iota (length xs + length ys))

-- | Element `i` of the result is element `(i+r) % length xs` of `xs`.
def rotate 't (r: i64) (xs: []t) =
  map (\i -> xs[(i+r) % length xs]) (iota (length xs))

-- | Swap the two outer dimensions of a two-dimensional array.
def transpose [n] [m] 't (a: [n][m]t): [m][n]t =
  map (\i -> map (\j -> a[j,i]) (iota n)) (iota m)

-- | Lay out the rows of `xs` one after another in a single array.
def flatten [n][m] 't (xs: [n][m]t): []t =
  map (\i -> xs[i/m, i%m]) (iota (n*m))

-- | View `xs` as `n` rows of `m` elements; row `i` starts at flat
-- index `i*m`.
def unflatten 't (n: i64) (m: i64) (xs: []t): [n][m]t =
  map (\i -> map (\j -> xs[i*m+j]) (iota m)) (iota n)

-- | `(x + y - 1) / y`, i.e. `x / y` rounded up for positive operands.
def div_rounding_up (x: i64) (y: i64) = (x + y - 1) / y

-- | Reduce `as` by repeatedly combining neighbouring pairs until at
-- most one element is left.  A missing right (or left) neighbour is
-- replaced by `ne`.  Returns `ne` for an empty array.
def reduce_tree 'a (op: a -> a -> a) (ne: a) (as: []a): a =
  let as' =
    loop as while length as > 1 do
      map (\i ->
             let x = if i*2 >= length as then ne else as[i*2]
             let y = if i*2+1 >= length as then ne else as[i*2+1]
             in x `op` y)
          (iota (length as `div_rounding_up` 2))
  in if length as' == 0 then ne else as'[0]

-- | Number of partial results that `reduce` computes before handing
-- over to `reduce_tree`.
def num_threads : i64 = 128 * 256

-- | Reduce `as` with `op`, whose neutral element is `ne`.
--
-- The input is cut into `num_threads` contiguous chunks, each chunk
-- is folded sequentially starting from `ne`, and the partial results
-- are combined with `reduce_tree`.  Chunks are contiguous (rather
-- than interleaved) so that operands are always combined in input
-- order; `op` therefore only has to be associative, not commutative.
def reduce [n] 'a (op: a -> a -> a) (ne: a) (as: [n]a): a =
  let chunk_size = n `div_rounding_up` num_threads
  let partial_results =
    map (\t ->
           loop x = ne for i < chunk_size do
             let j = t * chunk_size + i
             -- The final chunks may reach past the end of the input.
             in if j < n then x `op` as[j] else x)
        (iota num_threads)
  in reduce_tree op ne partial_results

-- | Inclusive prefix scan of `as` with `op`.  The neutral element is
-- accepted for interface compatibility but not used.
--
-- Each pass combines every element with the one `stride` positions
-- to its left, and `stride` doubles until it covers the whole array.
-- The earlier element is the left operand, so `op` only has to be
-- associative.  The pass count is driven by an integer stride, so it
-- is exact for every `n`, including 0.
def scan [n] 'a (op: a -> a -> a) (_ne: a) (as: [n]a): [n]a =
  let (res, _) =
    loop (acc, stride) = (as, 1) while stride < n do
      (map (\j -> if j < stride
                  then acc[j]
                  else acc[j - stride] `op` acc[j])
           (iota n),
       stride * 2)
  in res

-- | The elements of `as` that satisfy `p`, in their original order.
def filter 'a (p: a -> bool) (as: []a): *[]a =
  -- 1 for elements to keep, 0 for the rest.
  let keep = map (\a -> if p a then 1 else 0) as
  -- Running count of kept elements; for a kept element this is its
  -- one-based position in the output.
  let offsets = scan (+) 0 keep
  let num_to_keep = reduce (+) 0 keep
  in if num_to_keep == 0
     then []
     else scatter (replicate num_to_keep as[0])
                  -- Rejected elements get index -1.
                  (map (\(i, k) -> if k == 1 then i-1 else -1)
                       (zip offsets keep))
                  as

-- | Apply the chunk function `f` to `as`, here always with chunks of
-- size one, and concatenate the results.
def stream_map 'a 'b (f: (c: i64) -> [c]a -> [c]b) (as: []a): []b =
  as |> unflatten (length as) 1 |> map (f 1) |> flatten

-- | Apply the chunk function `f` to `as`, here always with chunks of
-- size one, and combine the per-chunk results with `op`.  The result
-- of `f` on the empty chunk serves as the neutral element.
def stream_red 'a 'b (op: b -> b -> b) (f: (c: i64) -> [c]a -> b) (as: []a): b =
  as |> unflatten (length as) 1 |> map (f 1) |> reduce op (f 0 [])