from networks access *;

// Assembles and returns the "four_choose_two" network diagram.
//
// Layout (all coordinates are the literal values passed to the networks API):
//   * Top row: six green 10x10 input tensors at x = 40, 80, ..., 240, y = 0.
//     When `named_inputs` is true they are labelled $X^{(1)}$, $X^{(3)}$,
//     $X^{(2)}$, $X^{(6)}$, $X^{(4)}$, $X^{(5)}$ from left to right.
//     Each input has a 5x5 tensor 30 units below it. That tensor is joined
//     upward by a wire of length 30.
//   * Three groups pair inputs (1,2), (3,4) and (5,6). In each group a mode
//     join at y = -50 is wired to both 5x5 tensors and to a new 10x10 tensor.
//     The first input of the pair also gets a wire to that 10x10 tensor.
//     The second input gets a wire to a new 5x5 tensor beside it.
//   * Two more mode joins under the middle group collect, respectively,
//     the three 10x10 tensors and the three 5x5 tensors.
//   * Execution nodes are then added and chained with exec_join.
//     NOTE(review): the rendering/meaning of execution nodes is defined in
//     the `networks` module, not here.
//
// named_inputs: when true, attach the $X^{(k)}$ labels to the input tensors.
// Returns the populated network.
network four_choose_two(bool named_inputs=true) {
  network net;

  // ---- Top row: inputs and the small tensor hanging under each. ----
  string[] inputOrder = {"1", "3", "2", "6", "4", "5"};
  tensor[] inputs = new tensor[];
  tensor[] topIds = new tensor[];
  for (int k = 1; k <= 6; ++k) {
    inputs[k] = net.add_tensor((40*k, 0), 10, 10, pen=green);
    if (named_inputs)
      inputs[k].label = "$X^{(" + inputOrder[k-1] + ")}$";
    // The third small tensor sits to the right of its input, all others to
    // the left (the wire from input 3 runs down its left side, see below).
    int shift = (k == 3) ? 5 : -5;
    topIds[k] = net.add_tensor((40*k + shift, -30), 5, 5);
    net.join(topIds[k].mid, topIds[k].mid + (0, 30));
  }

  // ---- Middle row: one mode join + big/small tensor pair per input pair. ----
  tensor[] bigs = new tensor[];
  tensor[] smalls = new tensor[];
  for (int g = 0; g < 3; ++g) {
    int lhs = 2*g + 1, rhs = 2*g + 2;  // 1-based indices of the pair's inputs
    int cx = 40*lhs + 20;              // horizontal centre of this group
    mode_join fork = net.add_mode_join((cx-5, -50));
    bigs[g] = net.add_tensor((cx-10, -80), 10, 10);
    smalls[g] = net.add_tensor((cx+10, -80), 5, 5);

    // Curved wires from the mode join up to both small tensors of the pair.
    net.add_path(fork.mid{left} .. controls (topIds[lhs].mid.x, fork.mid.y) .. topIds[lhs].mid-(0, 5){down});
    net.add_path(fork.mid{right} .. controls (topIds[rhs].mid.x, fork.mid.y) .. topIds[rhs].mid-(0,5){down});
    net.join(fork.mid, bigs[g].mid + (5,0));

    // Wire from the first input of the pair to the big tensor. It leaves the
    // input on the right side, except for the middle group (odd g), where it
    // leaves on the left.
    int side = (g % 2 == 0) ? 5 : -5;
    pair src = inputs[lhs].mid + (side, -10);
    pair dst = bigs[g].mid + (-5, 10);
    if (src.x != dst.x)
      net.add_path(src{down} .. controls src-(0,30) and dst+(0,10) .. dst{up});
    else
      net.join(src, dst);  // vertically aligned: a straight wire suffices

    // Wire from the second input of the pair to the small tensor.
    pair sOut = inputs[rhs].mid;
    net.add_path(sOut+(5, -10){down} .. controls sOut+(5, -50) and smalls[g].mid+(0, 20) .. smalls[g].mid+(0,5){up});
  }

  // ---- Bottom: two mode joins under the middle group. ----
  // The first collects all big tensors, the second all small ones.
  int midX = 40*3 + 20;  // centre of the middle group (g == 1)
  for (int r = 0; r < 2; ++r) {
    mode_join hub = net.add_mode_join((midX - 10 + 20*r, -110 - 10*r));
    tensor[] row = (r == 0) ? bigs : smalls;
    for (tensor t : row)
      net.add_path(hub.mid--(t.mid.x, hub.mid.y)--t.mid);
  }

  // ---- Execution nodes. ----
  // One node below each small top tensor, fed by its input and by itself.
  execution_node[] leaf = new execution_node[];
  for (int k = 1; k <= 6; ++k) {
    execution_node en = net.add_execution_node(topIds[k].mid + (-2, -20));
    leaf[k] = en;
    pair bend = en.mid + (3*(inputs[k].exec.mid.x - en.mid.x), 15);
    net.exec_join(inputs[k].exec, en, bend);
    net.exec_join(topIds[k].exec, en);
  }

  // Per group: merge the pair's two leaf nodes, then chain through the big
  // tensor and the small tensor. The last node of each chain is kept.
  execution_node[] merged = new execution_node[];
  for (int g = 0; g < 3; ++g) {
    execution_node both = net.add_execution_node(bigs[g].mid + (0, 20));
    net.exec_join(leaf[2*g+1], both);
    net.exec_join(leaf[2*g+2], both);
    execution_node afterBig = net.add_execution_node(bigs[g].mid - (5, 35));
    net.exec_join(bigs[g].exec, afterBig);
    net.exec_join(both, afterBig, (both.mid + afterBig.mid)/2 + (-25, 10));
    execution_node afterSmall = net.add_execution_node(smalls[g].mid - (5, 35));
    net.exec_join(afterBig, afterSmall);
    net.exec_join(smalls[g].exec, afterSmall);
    merged.push(afterSmall);
  }

  // Combine the three group chains: groups 0 and 1 first, then group 2.
  execution_node joinAB = net.add_execution_node(merged[1].mid + (0, -20));
  net.exec_join(merged[0], joinAB);
  net.exec_join(merged[1], joinAB);
  execution_node sink = net.add_execution_node(merged[2].mid + (0, -40));
  net.exec_join(merged[2], sink);
  net.exec_join(joinAB, sink);
  return net;
}
