diff --git a/ssh/kex.c b/ssh/kex.c index 7198639..e0164a0 100644 --- a/ssh/kex.c +++ b/ssh/kex.c @@ -229,28 +229,41 @@ kex_send_kexinit(struct ssh *ssh) int kex_input_kexinit(int type, u_int32_t seq, struct ssh *ssh) { + Kex *kex = ssh->kex; char *ptr; u_int i, dlen; - Kex *kex = ssh->kex; + int r; debug("SSH2_MSG_KEXINIT received"); if (kex == NULL) - fatal("kex_input_kexinit: no kex, cannot rekey"); + return SSH_ERR_INVALID_ARGUMENT; ptr = ssh_packet_get_raw(ssh, &dlen); - buffer_append(&kex->peer, ptr, dlen); + if ((r = sshbuf_put(&kex->peer, ptr, dlen)) != 0) + return r; /* discard packet */ for (i = 0; i < KEX_COOKIE_LEN; i++) - ssh_packet_get_char(ssh); + if ((r = sshpkt_get_u8(ssh, NULL)) != 0) + return r; for (i = 0; i < PROPOSAL_MAX; i++) - xfree(ssh_packet_get_string(ssh, NULL)); - (void) ssh_packet_get_char(ssh); - (void) ssh_packet_get_int(ssh); - ssh_packet_check_eom(ssh); + if ((r = sshpkt_get_string(ssh, NULL, NULL)) != 0) + return r; + if ((r = sshpkt_get_u8(ssh, NULL)) != 0 || + (r = sshpkt_get_u32(ssh, NULL)) != 0 || + (r = sshpkt_get_end(ssh)) != 0) + return r; - kex_kexinit_finish(ssh); - return 0; + /* XXX check error */ + if (!(kex->flags & KEX_INIT_SENT)) + kex_send_kexinit(ssh); + kex_choose_conf(ssh); + + if (kex->kex_type >= 0 && kex->kex_type < KEX_MAX && + kex->kex[kex->kex_type] != NULL) + return (kex->kex[kex->kex_type])(ssh); + + return SSH_ERR_INTERNAL_ERROR; } Kex * @@ -288,23 +301,6 @@ kex_setup(struct ssh *ssh, char *proposal[PROPOSAL_MAX]) return ssh->kex; } -static void -kex_kexinit_finish(struct ssh *ssh) -{ - Kex *kex = ssh->kex; - if (!(kex->flags & KEX_INIT_SENT)) - kex_send_kexinit(ssh); - - kex_choose_conf(ssh); - - if (kex->kex_type >= 0 && kex->kex_type < KEX_MAX && - kex->kex[kex->kex_type] != NULL) { - (kex->kex[kex->kex_type])(ssh); - } else { - fatal("Unsupported key exchange %d", kex->kex_type); - } -} - static void choose_enc(Enc *enc, char *client, char *server) { diff --git a/ssh/kex.h b/ssh/kex.h index b0fb70a..caac4de 100644 --- a/ssh/kex.h +++ b/ssh/kex.h @@ -125,7 +125,7 @@ struct Kex { struct sshkey *(*load_host_public_key)(int, struct ssh *); struct sshkey *(*load_host_private_key)(int, struct ssh *); int (*host_key_index)(struct sshkey *); - void (*kex[KEX_MAX])(struct ssh *); + int (*kex[KEX_MAX])(struct ssh *); /* kex specific state */ DH *dh; /* DH */ int min, max, nbits; /* GEX */ @@ -150,12 +150,12 @@ int kex_derive_keys(struct ssh *, u_char *, u_int, BIGNUM *); Newkeys *kex_get_newkeys(struct ssh *, int); -void kexdh_client(struct ssh *); -void kexdh_server(struct ssh *); -void kexgex_client(struct ssh *); -void kexgex_server(struct ssh *); -void kexecdh_client(struct ssh *); -void kexecdh_server(struct ssh *); +int kexdh_client(struct ssh *); +int kexdh_server(struct ssh *); +int kexgex_client(struct ssh *); +int kexgex_server(struct ssh *); +int kexecdh_client(struct ssh *); +int kexecdh_server(struct ssh *); int kex_dh_hash(char *, char *, char *, size_t, char *, size_t, u_char *, size_t, diff --git a/ssh/kexdhc.c b/ssh/kexdhc.c index f8988c4..f75ec03 100644 --- a/ssh/kexdhc.c +++ b/ssh/kexdhc.c @@ -45,7 +45,7 @@ static int input_kex_dh(int, u_int32_t, struct ssh *); -void +int kexdh_client(struct ssh *ssh) { Kex *kex = ssh->kex; @@ -81,9 +81,9 @@ kexdh_client(struct ssh *ssh) #endif debug("expecting SSH2_MSG_KEXDH_REPLY"); ssh_dispatch_set(ssh, SSH2_MSG_KEXDH_REPLY, &input_kex_dh); - return; + r = 0; out: - fatal("%s: %s", __func__, ssh_err(r)); + return r; } static int diff --git a/ssh/kexdhs.c b/ssh/kexdhs.c index dd11bcf..06194cf 100644 --- a/ssh/kexdhs.c +++ b/ssh/kexdhs.c @@ -48,7 +48,7 @@ static int input_kex_dh_init(int, u_int32_t, struct ssh *); -void +int kexdh_server(struct ssh *ssh) { Kex *kex = ssh->kex; @@ -75,9 +75,9 @@ kexdh_server(struct ssh *ssh) debug("expecting SSH2_MSG_KEXDH_INIT"); ssh_dispatch_set(ssh, SSH2_MSG_KEXDH_INIT, &input_kex_dh_init); - return; + r = 0; out: - fatal("%s: %s", __func__, ssh_err(r)); + return r; } int diff --git a/ssh/kexecdhc.c b/ssh/kexecdhc.c index 72ac8d5..b26026b 100644 --- a/ssh/kexecdhc.c +++ b/ssh/kexecdhc.c @@ -47,7 +47,7 @@ static int input_kex_ecdh_reply(int, u_int32_t, struct ssh *); -void +int kexecdh_client(struct ssh *ssh) { Kex *kex = ssh->kex; @@ -86,11 +86,11 @@ kexecdh_client(struct ssh *ssh) debug("expecting SSH2_MSG_KEX_ECDH_REPLY"); ssh_dispatch_set(ssh, SSH2_MSG_KEX_ECDH_REPLY, &input_kex_ecdh_reply); - return; + r = 0; out: if (client_key) EC_KEY_free(client_key); - fatal("%s: %s", __func__, ssh_err(r)); + return r; } static int diff --git a/ssh/kexecdhs.c b/ssh/kexecdhs.c index 37184c3..14ef5d5 100644 --- a/ssh/kexecdhs.c +++ b/ssh/kexecdhs.c @@ -49,11 +49,12 @@ static int input_kex_ecdh_init(int, u_int32_t, struct ssh *); -void +int kexecdh_server(struct ssh *ssh) { debug("expecting SSH2_MSG_KEX_ECDH_INIT"); ssh_dispatch_set(ssh, SSH2_MSG_KEX_ECDH_INIT, &input_kex_ecdh_init); + return 0; } static int diff --git a/ssh/kexgexc.c b/ssh/kexgexc.c index 5969c2e..d692b00 100644 --- a/ssh/kexgexc.c +++ b/ssh/kexgexc.c @@ -48,7 +48,7 @@ static int input_kex_dh_gex_group(int, u_int32_t, struct ssh *); static int input_kex_dh_gex_reply(int, u_int32_t, struct ssh *); -void +int kexgex_client(struct ssh *ssh) { Kex *kex = ssh->kex; @@ -84,9 +84,9 @@ kexgex_client(struct ssh *ssh) #endif ssh_dispatch_set(ssh, SSH2_MSG_KEX_DH_GEX_GROUP, &input_kex_dh_gex_group); - return; + r = 0; out: - fatal("%s: %s", __func__, ssh_err(r)); + return r; } static int diff --git a/ssh/kexgexs.c b/ssh/kexgexs.c index b8ae3bf..98a0a81 100644 --- a/ssh/kexgexs.c +++ b/ssh/kexgexs.c @@ -52,7 +52,7 @@ static int input_kex_dh_gex_request(int, u_int32_t, struct ssh *); static int input_kex_dh_gex_init(int, u_int32_t, struct ssh *); -void +int kexgex_server(struct ssh *ssh) { ssh_dispatch_set(ssh, SSH2_MSG_KEX_DH_GEX_REQUEST_OLD, @@ -60,7 +60,7 @@ kexgex_server(struct ssh *ssh) ssh_dispatch_set(ssh, SSH2_MSG_KEX_DH_GEX_REQUEST, &input_kex_dh_gex_request); debug("expecting SSH2_MSG_KEX_DH_GEX_REQUEST"); - return; + return 0; } static int