diff options
author | Andreas Baumann <mail@andreasbaumann.cc> | 2015-08-15 21:19:23 +0200 |
---|---|---|
committer | Andreas Baumann <mail@andreasbaumann.cc> | 2015-08-15 21:19:23 +0200 |
commit | 556a78c5e84f6dae156fd91ee7447427a0d081d4 (patch) | |
tree | 369817f89ee5df2cb78e83458037c8c75daf5fc0 | |
parent | 8552cab93e898e184b5887359971e9a62ff705fe (diff) | |
download | cssh-556a78c5e84f6dae156fd91ee7447427a0d081d4.tar.gz cssh-556a78c5e84f6dae156fd91ee7447427a0d081d4.tar.bz2 |
tried to parallelize, nothing but races
-rw-r--r-- | src/cssh.c | 521 | ||||
-rw-r--r-- | src/cssh_options.ggo | 2 |
2 files changed, 344 insertions, 179 deletions
@@ -7,10 +7,13 @@ #include <errno.h> #include <stdbool.h> #include <unistd.h> +#include <sys/types.h> +#include <pwd.h> -#include "cssh_options.h" #include "msleep.h" +#include "cssh_options.h" + #define CSSH_VERSION "0.0.1" static int parse_options_and_arguments( int argc, char *argv[], struct gengetopt_args_info *args_info ) { @@ -163,11 +166,28 @@ static int authenticate_pubkey( ssh_session session ) return -1; } -static int authenticate_password( ssh_session session, const char *user ) +static int authenticate_password( ssh_session session, const char *host, const unsigned short port, const char *user ) { + char prompt[128]; char pass[128]; + + const char *prompt_user; + if( user != NULL ) { + prompt_user = user; + } else { + uid_t uid = getuid( ); + struct passwd *pw = getpwuid( uid ); + if( pw != NULL ) { + prompt_user = pw->pw_name; + } else { + prompt_user = "(unknown)"; + } + } + + snprintf( prompt, sizeof( prompt ), "%s@%s:%d's password:", + prompt_user, host, port ); memset( pass, 0, sizeof( pass ) ); - if( ssh_getpass( "Password: ", pass, sizeof( pass ), 0, 0 ) < 0 ) { + if( ssh_getpass( prompt, pass, sizeof( pass ), 0, 0 ) < 0 ) { fprintf( stderr, "ERROR: ssh_getpass failed\n" ); ssh_disconnect( session ); ssh_free( session ); @@ -198,12 +218,84 @@ static int authenticate_password( ssh_session session, const char *user ) return -1; } +static int read_hosts_file( const char *hosts_file, unsigned short default_port, char ***host, unsigned short **port, unsigned int *nof_hosts ) +{ + FILE *f; + char buf[255]; + + f = fopen( hosts_file, "r" ); + if( f == NULL ) { + fprintf( stderr, "ERROR: Opening the host file '%s' failed: %s\n", + hosts_file, strerror( errno ) ); + return -1; + } + + *nof_hosts = 0; + unsigned int size_hosts = 2; + *host = (char **)malloc( size_hosts * sizeof( char * ) ); + *port = (unsigned short *)malloc( size_hosts * sizeof( unsigned short ) ); + if( *host == NULL || *port == NULL ) { + fprintf( stderr, "ERROR: Memory allocation failed in 'read_hosts_file'" ); + return -1; + } + + while( fgets( buf, sizeof( buf ), f ) != NULL ) { + if( strlen( buf ) > 0 && buf[strlen( buf )-1] == '\n' ) { + buf[strlen(buf)-1] = '\0'; + } + (*host)[*nof_hosts] = strdup( buf ); + // TODO: parse port from file + // TODO: parse user from file (but actually we can use the user@host hack in libssh) + (*port)[*nof_hosts] = default_port; + (*nof_hosts)++; + if( *nof_hosts > size_hosts ) { + size_hosts *= 2; + *host = (char **)realloc( *host, size_hosts * sizeof( char *) ); + *port = (unsigned short *)realloc( *port, size_hosts * sizeof( unsigned short ) ); + if( *host == NULL || *port == NULL ) { + fprintf( stderr, "ERROR: Memory allocation failed in 'read_hosts_file'" ); + return -1; + } + } + } + + fclose( f ); + + return 0; +} + +static void cleanup_sessions( ssh_session **session, ssh_channel **channel, char **host, unsigned short *port, const int nof_sessions, bool verbose ) +{ + for( unsigned int i = 0; i < nof_sessions; i++ ) { + if( ssh_is_connected( (*session)[i] ) ) { + if( verbose ) { + fprintf( stderr, "Disconnecting from '%s', port %d..\n", host[i], port[i] ); + } + if( channel != NULL ) { + if( ssh_channel_is_open( (*channel)[i] ) ) { + ssh_channel_close( (*channel)[i] ); + ssh_channel_free( (*channel)[i] ); + } + } + ssh_disconnect( (*session)[i] ); + if( verbose ) { + fprintf( stderr, "Disconnected from '%s', port %d..\n", host[i], port[i] ); + } + } + ssh_free( (*session)[i] ); + } + if( channel != NULL ) { + free( *channel ); + } + free( *session ); + free( *host ); + free( port ); +} + int main( int argc, char *argv[] ) { struct gengetopt_args_info args_info; - ssh_session session; - int port = 22; - const char *host = "localhost"; + ssh_session *session; int rc; if( parse_options_and_arguments( argc, argv, &args_info ) != 0 ) { @@ -219,95 +311,144 @@ int main( int argc, char *argv[] ) exit( EXIT_SUCCESS ); } - session = ssh_new( ); - if( session == NULL ) { - exit( EXIT_FAILURE ); + unsigned int nof_sessions = 0; + unsigned int command_pos = 0; + char **host = NULL; + unsigned short *port = NULL; + if( args_info.hosts_file_given ) { + unsigned nof_hosts; + unsigned short default_port = 22; + if( args_info.port_given ) { + default_port = args_info.port_arg; + } + rc = read_hosts_file( args_info.hosts_file_arg, default_port, &host, &port, &nof_hosts ); + if( rc < 0 ) { + exit( EXIT_SUCCESS ); + } + nof_sessions = nof_hosts; + } else { + unsigned short default_port = 22; + host = (char **)malloc( sizeof( char * ) ); + port = (unsigned short *)malloc( sizeof( unsigned short ) ); + if( args_info.inputs_num > 1 ) { + command_pos++; + nof_sessions = 1; + host[0] = strdup( args_info.inputs[0] ); + port[0] = default_port; + } else { + host[0] = strdup( "localhost" ); + port[0] = default_port; + } + if( args_info.port_given ) { + port[0] = args_info.port_arg; + } } - - if( args_info.inputs_num > 1 ) { - host = args_info.inputs[0]; + + session = (ssh_session *)malloc( nof_sessions * sizeof( ssh_session ) ); + if( session == NULL ) { + fprintf( stderr, "ERROR: Memory allocation failed for ssh_sessions" ); + return -1; } - ssh_options_set( session, SSH_OPTIONS_HOST, host ); - - if( args_info.port_given ) { - port = args_info.port_arg; + for( unsigned int i = 0; i < nof_sessions; i++ ) { + session[i] = ssh_new( ); + if( session == NULL ) { + exit( EXIT_FAILURE ); + } } - ssh_options_set( session, SSH_OPTIONS_PORT, &port ); - if( args_info.verbose_given ) { - int verbosity = SSH_LOG_NOLOG; + int verbosity = SSH_LOG_NOLOG; + if( args_info.verbose_given > 0 ) { verbosity += args_info.verbose_given; - ssh_options_set( session, SSH_OPTIONS_LOG_VERBOSITY, &verbosity ); } - - rc = ssh_connect( session ); - if( rc != SSH_OK ) { - fprintf( stderr, "ERROR: error connecting to '%s', port '%d': %s\n", - host, port, ssh_get_error( session ) ); - ssh_free( session ); - exit( EXIT_FAILURE ); + + for( unsigned int i = 0; i < nof_sessions; i++ ) { + ssh_options_set( session[i], SSH_OPTIONS_HOST, host[i] ); + ssh_options_set( session[i], SSH_OPTIONS_PORT, &port[i] ); + ssh_options_set( session[i], SSH_OPTIONS_LOG_VERBOSITY, &verbosity ); } - if( verify_knownhost( session ) < 0 ) { - fprintf( stderr, "ERROR: closing connection to '%s', port '%d' due to security reasons\n", - host, port ); - ssh_disconnect( session ); - ssh_free( session ); - exit( EXIT_FAILURE ); - } + for( unsigned int i = 0; i < nof_sessions; i++ ) { + if( args_info.verbose_given ) { + fprintf( stderr, "Connecting to '%s', port %d..\n", host[i], port[i] ); + } + rc = ssh_connect( session[i] ); + if( rc != SSH_OK ) { + fprintf( stderr, "ERROR: error connecting to '%s', port '%d': %s\n", + host[i], port[i], ssh_get_error( session[i] ) ); + cleanup_sessions( &session, NULL, host, port, nof_sessions, args_info.verbose_given > 0 ); + exit( EXIT_FAILURE ); + } + + if( verify_knownhost( session[i] ) < 0 ) { + fprintf( stderr, "ERROR: closing connection to '%s', port '%d' due to security reasons\n", + host[i], port[i] ); + cleanup_sessions( &session, NULL, host, port, nof_sessions, args_info.verbose_given > 0 ); + exit( EXIT_FAILURE ); + } - rc = ssh_userauth_none( session, NULL ); - if( rc == SSH_AUTH_ERROR ) { - fprintf( stderr, "ERROR: ssh_userauth_none to '%s', port '%d' failed: %s\n", - host, port, ssh_get_error( session ) ); - ssh_free( session ); - exit( EXIT_FAILURE ); - } + rc = ssh_userauth_none( session[i], NULL ); + if( rc == SSH_AUTH_ERROR ) { + fprintf( stderr, "ERROR: ssh_userauth_none to '%s', port '%d' failed: %s\n", + host[i], port[i], ssh_get_error( session[i] ) ); + cleanup_sessions( &session, NULL, host, port, nof_sessions, args_info.verbose_given > 0 ); + exit( EXIT_FAILURE ); + } + + char *banner = ssh_get_issue_banner( session[i] ); + if( banner ) { + puts( banner ); + free( banner ); + } - char *banner = ssh_get_issue_banner( session ); - if( banner ) { - puts( banner ); - free( banner ); - } + rc = authenticate_pubkey( session[i] ); + if( rc == SSH_AUTH_ERROR ) { + cleanup_sessions( &session, NULL, host, port, nof_sessions, args_info.verbose_given > 0 ); + exit( EXIT_FAILURE ); + } - rc = authenticate_pubkey( session ); - if( rc == SSH_AUTH_ERROR ) { - ssh_disconnect( session ); - ssh_free( session ); - exit( EXIT_FAILURE ); - } + if( rc != SSH_AUTH_SUCCESS ) { + rc = authenticate_password( session[i], host[i], port[i], args_info.login_given ? args_info.login_arg : NULL ); + if( rc < 0 ) { + cleanup_sessions( &session, NULL, host, port, nof_sessions, args_info.verbose_given > 0 ); + exit( EXIT_FAILURE ); + } + } - if( rc != SSH_AUTH_SUCCESS ) { - rc = authenticate_password( session, args_info.login_given ? args_info.login_arg : NULL ); - if( rc < 0 ) { - ssh_disconnect( session ); - ssh_free( session ); - exit( EXIT_FAILURE ); + if( args_info.verbose_given ) { + fprintf( stderr, "Connected to '%s', port %d..\n", host[i], port[i] ); } } - - ssh_channel channel = ssh_channel_new( session ); + + ssh_channel *channel = (ssh_channel *)malloc( ( nof_sessions + 1 ) * sizeof( ssh_channel ) ); + memset( channel, 0, nof_sessions + 1 ); if( channel == NULL ) { - fprintf( stderr, "ERROR: Unable to open SSH channel: %s\n", - ssh_get_error( session ) ); - ssh_disconnect( session ); - ssh_free( session ); - exit( EXIT_FAILURE ); + fprintf( stderr, "ERROR: Memory allocation failed for ssh_channels" ); + cleanup_sessions( &session, NULL, host, port, nof_sessions, args_info.verbose_given > 0 ); + return -1; + } + for( unsigned int i = 0; i < nof_sessions; i++ ) { + channel[i] = ssh_channel_new( session[i] ); + if( channel[i] == NULL ) { + fprintf( stderr, "ERROR: Unable to open SSH channel: %s\n", + ssh_get_error( session[i] ) ); + cleanup_sessions( &session, &channel, host, port, nof_sessions, args_info.verbose_given > 0 ); + exit( EXIT_FAILURE ); + } } - rc = ssh_channel_open_session( channel ); - if( rc != SSH_OK ) { - ssh_channel_free( channel ); - ssh_disconnect( session ); - ssh_free( session ); - exit( EXIT_FAILURE ); + for( unsigned int i = 0; i < nof_sessions; i++ ) { + rc = ssh_channel_open_session( channel[i] ); + if( rc != SSH_OK ) { + cleanup_sessions( &session, &channel, host, port, nof_sessions, args_info.verbose_given > 0 ); + exit( EXIT_FAILURE ); + } } char cmd[1024]; cmd[0] = '\0'; - if( args_info.inputs_num > 1 ) { - for( int i = 1; i < args_info.inputs_num; i++ ) { - if( i != 1 ) { + if( args_info.inputs_num > 0 ) { + for( int i = command_pos; i < args_info.inputs_num; i++ ) { + if( i != command_pos ) { strncat( cmd, " ", sizeof( cmd ) - strlen( cmd ) - 1 ); } strncat( cmd, args_info.inputs[i], sizeof( cmd ) - strlen( cmd ) - 1 ); @@ -315,133 +456,157 @@ int main( int argc, char *argv[] ) } if( cmd[0] == '\0' ) { fprintf( stderr, "ERROR: Empty command, no interactive CLI supported currently\n" ); - ssh_channel_free( channel ); - ssh_disconnect( session ); - ssh_free( session ); - exit( EXIT_FAILURE ); - } - - rc = ssh_channel_request_exec( channel, cmd ); - if( rc != SSH_OK ) { - ssh_channel_close( channel ); - ssh_channel_free( channel ); - ssh_disconnect( session ); - ssh_free( session ); + cleanup_sessions( &session, &channel, host, port, nof_sessions, args_info.verbose_given > 0 ); exit( EXIT_FAILURE ); } - - while( ssh_channel_is_open( channel ) && !ssh_channel_is_eof( channel ) ) { - char buffer[256]; - bool must_sleep = false; - - rc = ssh_channel_poll( channel, 0 ); - if( rc == SSH_ERROR ) { - fprintf( stderr, "ERROR: ssh_channel_poll on stdout failed: %s\n", - ssh_get_error( session ) ); - ssh_channel_close( channel ); - ssh_channel_free( channel ); - ssh_disconnect( session ); - ssh_free( session ); + + for( unsigned int i = 0; i < nof_sessions; i++ ) { + rc = ssh_channel_request_exec( channel[i], cmd ); + if( rc != SSH_OK ) { + fprintf( stderr, "ERROR: Executing SSH command '%s' failed: %s\n", + cmd, ssh_get_error( session[i] ) ); + cleanup_sessions( &session, &channel, host, port, nof_sessions, args_info.verbose_given > 0 ); exit( EXIT_FAILURE ); } - - if( rc == 0 ) { - cssh_msleep( 10 ); - } - - if( rc > 0 ) { - unsigned int nread = ssh_channel_read_nonblocking( channel, buffer, sizeof( buffer ), 0 ); - if( nread == SSH_ERROR ) { - fprintf( stderr, "ERROR: ssh_channel_read_nonblocking on stdout failed: %s\n", - ssh_get_error( session ) ); - ssh_channel_close( channel ); - ssh_channel_free( channel ); - ssh_disconnect( session ); - ssh_free( session ); - exit( EXIT_FAILURE ); - } + } - if( nread > 0 ) { - size_t wrc = fwrite( buffer, 1, nread, stdout ); - if( wrc < 0 ) { - fprintf( stderr, "ERROR: while writting to stdout: %s\n", - strerror( errno ) ); - ssh_channel_close( channel ); - ssh_channel_free( channel ); - ssh_disconnect( session ); - ssh_free( session ); - exit( EXIT_FAILURE ); - } + ssh_channel *read_channel = (ssh_channel *)malloc( ( nof_sessions + 1 ) * sizeof( ssh_channel ) ); + ssh_channel *write_channel = (ssh_channel *)malloc( ( nof_sessions + 1 ) * sizeof( ssh_channel ) ); + ssh_channel *except_channel = (ssh_channel *)malloc( ( nof_sessions + 1 ) * sizeof( ssh_channel ) ); + bool all_eof = false; + bool *eof_sent = (bool *)malloc( ( nof_sessions + 1 ) * sizeof( bool ) ); + memset( eof_sent, false, ( nof_sessions + 1 ) * sizeof( bool ) ); + + while( !all_eof ) { + struct timeval timeout; + timeout.tv_sec = 1; + timeout.tv_usec = 0; + + all_eof = true; + memset( read_channel, 0, ( nof_sessions + 1 ) * sizeof( ssh_channel ) ); + for( unsigned int i = 0, j = 0; i < nof_sessions; i++ ) { + if( !ssh_channel_is_closed( channel[i] ) && !ssh_channel_is_eof( channel[i] ) ) { + read_channel[j++] = channel[i]; + all_eof = false; + } + } - if( wrc != nread ) { - fprintf( stderr, "ERROR: Write mismatch on stdout (%zu != %d)\n", - wrc, nread ); - ssh_channel_close( channel ); - ssh_channel_free( channel ); - ssh_disconnect( session ); - ssh_free( session ); - exit( EXIT_FAILURE ); - } + memset( write_channel, 0, ( nof_sessions + 1 ) * sizeof( ssh_channel ) ); + for( unsigned int i = 0, j = 0; i < nof_sessions; i++ ) { + if( !ssh_channel_is_closed( channel[i] ) && !eof_sent[i] ) { + write_channel[j++] = channel[i]; } } - rc = ssh_channel_poll( channel, 1 ); + memset( except_channel, 0, ( nof_sessions + 1 ) * sizeof( ssh_channel ) ); + for( unsigned int i = 0, j = 0; i < nof_sessions; i++ ) { + if( !ssh_channel_is_closed( channel[i] ) ) { + except_channel[j++] = channel[i]; + } + } + + rc = ssh_channel_select( read_channel, write_channel, except_channel, &timeout ); if( rc == SSH_ERROR ) { - fprintf( stderr, "ERROR: ssh_channel_poll on stderr failed: %s\n", + fprintf( stderr, "ERROR: ssh_channel_select failed: %s\n", ssh_get_error( session ) ); - ssh_channel_close( channel ); - ssh_channel_free( channel ); - ssh_disconnect( session ); - ssh_free( session ); + cleanup_sessions( &session, &channel, host, port, nof_sessions, args_info.verbose_given > 0 ); exit( EXIT_FAILURE ); + } else if( rc == SSH_EINTR ) { + continue; } - - if( rc > 0 ) { - unsigned int nread = ssh_channel_read_nonblocking( channel, buffer, sizeof( buffer ), 1 ); - if( nread == SSH_ERROR ) { - fprintf( stderr, "ERROR: ssh_channel_read_nonblocking on stderr failed: %s\n", - ssh_get_error( session ) ); - ssh_channel_close( channel ); - ssh_channel_free( channel ); - ssh_disconnect( session ); - ssh_free( session ); - exit( EXIT_FAILURE ); + + // no stdin sent to commands on remote machines for now + for( unsigned int i = 0; i < nof_sessions; i++ ) { + int channel_idx = -1; + for( unsigned int j = 0; j < nof_sessions; j++ ) { + if( channel[j] == write_channel[i] ) { + channel_idx = j; + break; + } } - - if( nread > 0 ) { - size_t wrc = fwrite( buffer, 1, nread, stderr ); - if( wrc < 0 ) { - fprintf( stderr, "ERROR: while writting to stderr: %s\n", - strerror( errno ) ); - ssh_channel_close( channel ); - ssh_channel_free( channel ); - ssh_disconnect( session ); - ssh_free( session ); + if( write_channel[i] != NULL && channel_idx >= 0 && !eof_sent[channel_idx] ) { + ssh_channel_send_eof( write_channel[i] ); + eof_sent[channel_idx] = true; + } + } + + for( unsigned int i = 0; i < nof_sessions; i++ ) { + if( read_channel[i] != NULL ) { + char buffer[4096]; + rc = ssh_channel_poll( read_channel[i], 0 ); + if( rc == SSH_ERROR ) { + fprintf( stderr, "ERROR: ssh_channel_poll on stdout failed: %s\n", + ssh_get_error( ssh_channel_get_session( read_channel[i] ) ) ); + cleanup_sessions( &session, &channel, host, port, nof_sessions, args_info.verbose_given > 0 ); exit( EXIT_FAILURE ); } + + if( rc > 0 ) { + unsigned int nread = ssh_channel_read_nonblocking( read_channel[i], buffer, sizeof( buffer ), 0 ); + if( nread == SSH_ERROR ) { + fprintf( stderr, "ERROR: ssh_channel_read_nonblocking on stdout failed: %s\n", + ssh_get_error( ssh_channel_get_session( read_channel[i] ) ) ); + cleanup_sessions( &session, &channel, host, port, nof_sessions, args_info.verbose_given > 0 ); + exit( EXIT_FAILURE ); + } + + if( nread > 0 ) { + size_t wrc = fwrite( buffer, 1, nread, stdout ); + if( wrc < 0 ) { + fprintf( stderr, "ERROR: while writing to stdout: %s\n", + strerror( errno ) ); + cleanup_sessions( &session, &channel, host, port, nof_sessions, args_info.verbose_given > 0 ); + exit( EXIT_FAILURE ); + } + + if( wrc != nread ) { + fprintf( stderr, "ERROR: Write mismatch on stdout (%zu != %d)\n", + wrc, nread ); + cleanup_sessions( &session, &channel, host, port, nof_sessions, args_info.verbose_given > 0 ); + exit( EXIT_FAILURE ); + } + } + } - if( wrc != nread ) { - fprintf( stderr, "ERROR: Write mismatch on stderr (%zu != %d)\n", - wrc, nread ); - ssh_channel_close( channel ); - ssh_channel_free( channel ); - ssh_disconnect( session ); - ssh_free( session ); + rc = ssh_channel_poll( read_channel[i], 1 ); + if( rc == SSH_ERROR ) { + fprintf( stderr, "ERROR: ssh_channel_poll on stderr failed: %s\n", + ssh_get_error( ssh_channel_get_session( read_channel[i] ) ) ); + cleanup_sessions( &session, &channel, host, port, nof_sessions, args_info.verbose_given > 0 ); exit( EXIT_FAILURE ); } + + if( rc > 0 ) { + unsigned int nread = ssh_channel_read_nonblocking( read_channel[i], buffer, sizeof( buffer ), 1 ); + if( nread == SSH_ERROR ) { + fprintf( stderr, "ERROR: ssh_channel_read_nonblocking on stderr failed: %s\n", + ssh_get_error( ssh_channel_get_session( read_channel[i] ) ) ); + cleanup_sessions( &session, &channel, host, port, nof_sessions, args_info.verbose_given > 0 ); + exit( EXIT_FAILURE ); + } + + if( nread > 0 ) { + size_t wrc = fwrite( buffer, 1, nread, stderr ); + if( wrc < 0 ) { + fprintf( stderr, "ERROR: while writting to stderr: %s\n", + strerror( errno ) ); + cleanup_sessions( &session, &channel, host, port, nof_sessions, args_info.verbose_given > 0 ); + exit( EXIT_FAILURE ); + } + + if( wrc != nread ) { + fprintf( stderr, "ERROR: Write mismatch on stderr (%zu != %d)\n", + wrc, nread ); + cleanup_sessions( &session, &channel, host, port, nof_sessions, args_info.verbose_given > 0 ); + exit( EXIT_FAILURE ); + } + } + } } } - - if( must_sleep ) { - cssh_msleep( 10 ); - } } - ssh_channel_send_eof( channel ); - ssh_channel_close( channel ); - ssh_channel_free( channel ); - ssh_disconnect( session ); - ssh_free( session ); + cleanup_sessions( &session, &channel, host, port, nof_sessions, args_info.verbose_given > 0 ); exit( EXIT_SUCCESS ); } diff --git a/src/cssh_options.ggo b/src/cssh_options.ggo index 3090e57..5594c2b 100644 --- a/src/cssh_options.ggo +++ b/src/cssh_options.ggo @@ -23,7 +23,7 @@ section "Main Options" optional option "hosts-file" H - "List of hosts to use in parallel (listed one per line in a file)" + "List of hosts to use in parallel (listed one per line in a file), optional second parameter separated with a tab can be the port" string typestr="hosts-file" optional |