from networks access *;

// Linear membership test: true when some entry of A equals v, false otherwise
// (including when A is empty).
bool contains(int[] A, int v) {
  int idx = 0;
  while (idx < A.length && A[idx] != v)
    ++idx;
  return idx < A.length;
}

// Number of unordered pairs among k items, k*(k-1)/2.  The product of two
// consecutive integers is always even, so the integer quotient is exact.
int choose2(int k) {
  int ordered_pairs = k * (k - 1);
  return quotient(ordered_pairs, 2);
}

// Build and return the tensor-network diagram of the "binomial2" construction
// on modes 1..v.
//
// With v = 3*s + u (u = v % 3) the modes are split into three overlapping
// groups P[1], P[2], P[3].  Every unordered pair of modes (i, j), i < j, is
// assigned to the first group (in order 1, 2, 3) that contains both modes.
// Each group is drawn as a row of pair inputs feeding a row of mode joins,
// one join per mode of the group (the "first level").  For each index
// i < s, joins i and i+s of every group feed one multiplication tensor, and
// the three multiplications sharing index i are tied together by a further
// mode join.  The u leftover modes link group 1's join directly to group 2's.
//
// Parameters:
//   v                  number of modes.
//   attach_inputs      draw a green input tensor for every pair and, when
//                      every group has at least one pair, an execution chain.
//   label_input_modes  label each pair input with its two mode indices.
//   only_first_level   stop after the first level (labelling every mode join
//                      with its mode index) and return that partial network.
//   label_muls         label the multiplication tensors alpha/beta/gamma and
//                      put small "0", "1", "2" labels next to each of them.
//   first_level_height vertical offset of the first-level joins below the
//                      row of inputs.
network binomial2(int v,
                  bool attach_inputs=false,
                  bool label_input_modes=false,
                  bool only_first_level=false,
                  bool label_muls=false,
                  int first_level_height=60
                  ) {
  int u = v % 3;
  int s = quotient(v, 3);

  network N;

  // P[g] lists the (1-based) modes of group g; P[0] is an unused placeholder
  // so that groups are indexed 1..3.
  int P[][] = {{}, {}, {}, {}};
  for (int i = 1; i <= 2*s; ++i)
    P[1].push(i);
  for (int i = 3*s+1; i <= 3*s+u; ++i)
    P[1].push(i);
  for (int i = 1; i <= s; ++i)
    P[2].push(i);
  for (int i = 2*s+1; i <= 3*s+u; ++i)
    P[2].push(i);
  for (int i = s+1; i <= 3*s; ++i)
    P[3].push(i);

  // G[g] holds the mode pairs (i, j), i < j, assigned to group g: each pair
  // goes to the first group that contains both of its modes.
  pair G[][] = new pair[][];
  for (int g = 1; g <= 3; ++g)
    G[g] = new pair[];
  for (int i = 1; i <= v; ++i) {
    for (int j = i+1; j <= v; ++j) {
      for (int g = 1; g <= 3; ++g) {
        if (contains(P[g], i) && contains(P[g], j)) {
          G[g].push((i, j));
          break;
        }
      }
    }
  }

  // Gpos[g] is the left end of group g's row of inputs.  Rows are laid out
  // left to right, 25 units per pair plus a 20-unit gap between groups.
  pair Gpos[] = {(0, 0), (0, 0)};
  for (int g = 1; g <= 2; ++g)
    Gpos.push(Gpos[g] + (25*G[g].length + 20, 0));

  // First level: one mode join per mode of each group, wired to every input
  // pair of that group which contains the mode.
  mode_join J[][] = new mode_join[][];
  tensor A[][] = new tensor[][];
  for (int g = 1; g <= 3; ++g) {
    J[g] = new mode_join[];
    A[g] = new tensor[];
    // Center the row of joins under the row of inputs.
    pair mid = (Gpos[g] + Gpos[g] + (25*G[g].length-25, 0)) / 2;
    pair at = mid - 20*(P[g].length-1)/2 + (0, -first_level_height);
    for (int k = 0; k < G[g].length; ++k) {
      pair I = G[g][k];
      pair Apos = Gpos[g] + (25*k, attach_inputs ? 20 : 0);
      if (label_input_modes) {
        N.add_label(Label("$" + string(I.x) + "\;\," + string(I.y) + "$", Apos));
      }
      if (attach_inputs) {
        A[g].push(N.add_tensor(Apos - (0, 20), 10, 10, pen=green));
      }
    }
    for (int mode_index : P[g]) {
      mode_join join = N.add_mode_join(at);
      if (only_first_level) {
        N.add_label(Label("$" + string(mode_index) + "$", at + (0, -10)));
      }
      J[g].push(join);
      for (int k = 0; k < G[g].length; ++k) {
        pair I = G[g][k];
        if (I.x == mode_index || I.y == mode_index) {
          pair Apos = Gpos[g] + (25*k, 0);
          // The path leaves the pair's left side for its first mode and its
          // right side for its second mode.
          int shift = I.x == mode_index ? -5 : 5;
          N.add_path(smooth_vertical_path(Apos + (shift, -10), join.mid));
        }
      }
      at += (20, 0);
    }
  }

  if (only_first_level) {
    return N;
  }

  // Second level: for each index i < s, joins i and i+s of every group feed
  // one multiplication tensor placed 40 units below joins 2i and 2i+1.
  string names[] = {"", "\alpha", "\beta", "\gamma"};
  tensor mul[][] = new tensor[][];
  for (int g = 1; g <= 3; ++g) {
    mul[g] = new tensor[];
    for (int i = 0; i < s; ++i) {
      pair pos = (J[g][2*i].mid + J[g][2*i+1].mid) / 2.0 + (0, -40);
      string L = label_muls ? "$" + names[g] + "$" : "";
      mul[g][i] = N.add_tensor(pos, 12, 12, label=L);
      if (label_muls) {
        int bot_shift = g == 1 ? 5 : -5;
        N.add_label(Label("0", pos+(bot_shift,8), p=magenta+fontsize(8pt)));
        N.add_label(Label("1", pos+(-bot_shift,8), p=magenta+fontsize(8pt)));
        N.add_label(Label("2", pos+(0,-8.5), p=magenta+fontsize(8pt)));
      }
      N.add_path(smooth_vertical_path(J[g][i].mid, pos + (-5, 10)));
      N.add_path(smooth_vertical_path(J[g][i+s].mid, pos + (5, 10)));
    }
  }

  // Tie together the three multiplications sharing index i with one mode
  // join; successive joins are stacked 10 units apart so they do not overlap.
  for (int i = 0; i < s; ++i) {
    mode_join K = N.add_mode_join(mul[2][i].mid - (0, 40+10*i));
    N.add_path(K.mid--(mul[1][i].mid.x, K.mid.y)--mul[1][i].mid);
    N.join(K.mid, mul[2][i].mid);
    N.add_path(K.mid--(mul[3][i].mid.x, K.mid.y)--mul[3][i].mid);
  }

  // Leftover modes 3s+1..3s+u belong to both groups 1 and 2 (at position
  // 2s+i in each); link their two first-level joins directly.
  for (int i = 0; i < u; ++i) {
    pair j1pos = J[1][2*s+i].mid, j2pos = J[2][2*s+i].mid;
    N.add_path(j1pos -- j1pos+(0,-60-10*i) -- j2pos+(0,-60-10*i) -- j2pos);
  }

  // Execution order.  Each group chains its input tensors left to right, then
  // its multiplications from last to first, and the three chains are merged
  // at the end.  This needs at least one input in every group (true for
  // v >= 3).  For smaller v the chain is skipped, because otherwise A[g][0]
  // would be read from an empty array.
  if (attach_inputs && A[1].length > 0 && A[2].length > 0 && A[3].length > 0) {
    execution_node end[] = new execution_node[];
    for (int g = 1; g <= 3; ++g) {
      execution_node prev = A[g][0].exec;
      for (int i = 1; i < A[g].length; ++i) {
        execution_node next = N.add_execution_node((A[g][i].mid.x+ 5, J[g][0].mid.y - 10));
        if (i > 1) {
          N.exec_join(prev, next);
        } else {
          N.exec_join(prev, next, (prev.mid.x, next.mid.y));
        }
        N.exec_join(A[g][i].exec, next);
        prev = next;
      }
      for (int i = s-1; i >= 0; --i) {
        execution_node next = N.add_execution_node(mul[g][i].mid + (5, -25));
        if (i < s-1) {
          N.exec_join(prev, next);
        } else {
          N.exec_join(prev, next, (prev.mid.x, next.mid.y));
        }
        N.exec_join(mul[g][i].exec, next);
        prev = next;
      }
      end[g] = prev;
    }
    execution_node merge12 = N.add_execution_node(end[1].mid + (5, -20));
    N.exec_join(end[1], merge12);
    N.exec_join(end[2], merge12, (end[2].mid.x, merge12.mid.y));
    execution_node merge_all = N.add_execution_node(merge12.mid + (5, -20));
    N.exec_join(end[3], merge_all, (end[3].mid.x, merge_all.mid.y));
    N.exec_join(merge12, merge_all);
  }
  return N;
}
