How do I find out which socket client sent the message?

I am working on a chat program. I have a server and a client, and multiple users can connect to the server. Currently, the server just sends back whatever message a client sends to it. I would like to add authentication so that I can accept or decline a connection if authentication fails.

client:

import socket
import time

class Network:
    # initialize the socket
    def __init__(self, client, host=host, port=port):
        self.client = client
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.port = port
        self.host = host
        self.addr = (host, port)

    # connect to the server
    def connect(self):
        self.socket.connect(self.addr)

    # receive data from server if there is any
    def read(self):
        while True:
            time.sleep(0.1)
            try:
                data = self.socket.recv(1024)
            except socket.error:
                break
                # instead of breaking, create "connection lost" then open the login form again
            print "in client: ", data
            data_split = data.split("\r\n")
            for ds in data_split:
                self.client.msgbox.addMsg(ds)

    # send chat message to the server
    def send(self, msg):
        self.socket.send(msg)

    # authenticate user
    def authenticate(self, info):
        self.socket.send(info)

server:

import select
import socket

class Server:
    # init the socket
    def __init__(self, host=host, port=port):
        self.host = host
        self.port = port
        self.addr = (host, port)
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

    # send data to client
    def send(self, soc, data):
        try:
            soc.send(data)
        except socket.error:
            return "couldn't send message"

    # receive data from client
    def receive(self, soc):
        try:
            data = soc.recv(size)
        except socket.error:
            return disconnect
        # recv() returns an empty string once the client has closed the connection
        if not data:
            return disconnect
        return data

    # connect client
    def connect(self):
        self.socket.bind(self.addr)
        self.socket.listen(5)
        self.socket_s = [self.socket]
        self.read_socs = [self.socket]
        self.write_socs = []
        self.user_addr = {}

    # validate the user
    def validate(self, username, password):
        if username in users:
            sha = s256.new()
            sha.update(password)
            password = sha.hexdigest()

            if password == users[username]:
                print "in server: true"
                return True
            else:
                print "in server: false"
                return False

    # server
    def serve(self):
        while True:
            r_socs, w_socs, exceptions = select.select(self.read_socs, [], [])
            for s in r_socs:
                if s in self.socket_s:
                    print "accepting socket connect"
                    soc, address = s.accept()
                    print "in server: ", soc, address
                    self.read_socs.append(soc)
                    self.write_socs.append(soc)
                    for ws in self.write_socs:
                        self.send(ws, "len(users) == " + str(len(self.write_socs)) + "\n")
                        print connection
                else:
                    data = self.receive(s)
                    print "in server: " + data
                    if auth in data:
                        ds = data.split(" ")
                        res = self.validate(ds[1], ds[2])
                    elif data == disconnect:
                        s.close()
                        self.read_socs.remove(s)
                        self.write_socs.remove(s)
                        for ws in self.write_socs:
                            print "in server: ", ws
                            self.send(ws, "len(users) == " + str(len(self.write_socs)) + "\n")
                    else:
                        for ws in self.write_socs:
                            print "in server: ", ws
                            self.send(ws, data)

Your design is not actually going to work, because TCP is a byte stream: the data returned by one recv() call doesn't necessarily correspond to one send() call on the other side. It could be half a message, or 3 messages, or 5½ messages. If you're only testing on localhost with small messages, it will often seem to work, and then fail completely once you put it on the internet. That's why you need to build some kind of protocol on top of TCP that uses delimiters (like newlines), length prefixes (like netstrings), or self-delimiting objects (like JSON).
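To make the length-prefix option concrete, here is a minimal Python 3 sketch of netstring framing (a netstring is "<length>:<payload>,"). The names encode_netstring() and decode_netstrings() are hypothetical helpers, not a library API. The key point is that the decoder returns any incomplete trailing frame so it can be combined with the next recv() result.

```python
def encode_netstring(payload):
    """Frame one message as a netstring: b'<length>:<payload>,'."""
    return str(len(payload)).encode("ascii") + b":" + payload + b","


def decode_netstrings(buffer):
    """Split every complete netstring off the front of buffer.

    Returns (messages, leftover). leftover is an incomplete trailing frame
    (or b"") that must be prepended to the next chunk read from the socket.
    """
    messages = []
    while True:
        colon = buffer.find(b":")
        if colon == -1:
            break  # length prefix not fully received yet
        length = int(buffer[:colon])
        end = colon + 1 + length  # index where the trailing comma must be
        if len(buffer) < end + 1:
            break  # payload or trailing comma not fully received yet
        if buffer[end:end + 1] != b",":
            raise ValueError("malformed netstring")
        messages.append(buffer[colon + 1:end])
        buffer = buffer[end + 1:]
    return messages, buffer


# Two messages sent back to back, delivered by TCP as 9 bytes and then the rest.
stream = encode_netstring(b"hello") + encode_netstring(b"auth alice secret")
first, leftover = decode_netstrings(stream[:9])
second, leftover = decode_netstrings(leftover + stream[9:])
print(first, second)  # prints [b'hello'] [b'auth alice secret']
```

Because the payload is sliced by its declared length rather than searched for a delimiter, a message may itself contain ":" or "," without breaking the framing.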

At any rate, you already know which socket each message comes in on. You can map sockets to users, or just use the sockets themselves (or their file descriptors) to make decisions. So, just as you keep track of all the known sockets to pass to select, you also keep track of all the sockets known to be authenticated. If the socket a message comes in on is in that collection, it's authenticated; otherwise, the message is rejected unless it's an auth message.
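As a minimal illustration of mapping sockets to users, the Python 3 sketch below keys a dict on the server-side socket object. The USERS table, the "auth <user> <password>" line format, and handle_line() are hypothetical, chosen only for the demo; socket.socketpair() stands in for connections returned by accept().

```python
import socket

# Hypothetical credential table for illustration only; a real server would
# store password hashes, not plaintext.
USERS = {"alice": "secret", "bob": "hunter2"}

# Maps each authenticated client socket to its username.
user_by_socket = {}


def handle_line(sock, line):
    """Decide what to do with one complete line, based on the socket it arrived on."""
    parts = line.split(" ")
    if parts[0] == "auth":
        if len(parts) == 3 and USERS.get(parts[1]) == parts[2]:
            user_by_socket[sock] = parts[1]
            return "auth ok"
        return "auth failed"
    if sock not in user_by_socket:
        return "rejected"  # not authenticated yet: only auth lines are accepted
    return user_by_socket[sock] + ": " + line


# Each socketpair() stands in for one accepted client connection; the first
# socket of each pair plays the server-side socket you would get from accept().
server_a, client_a = socket.socketpair()
server_b, client_b = socket.socketpair()

for sock, line in [(server_b, "hello"),
                   (server_a, "auth alice secret"),
                   (server_a, "hello")]:
    print(handle_line(sock, line))
# prints "rejected", then "auth ok", then "alice: hello"

for s in (server_a, client_a, server_b, client_b):
    s.close()
```

Since every accepted connection gets its own socket object, the dict lookup answers "which client sent this?" directly. Remember to delete the entry when that client disconnects.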

Let's say you've got a simple line protocol:

import select
from collections import defaultdict

class Server:
    def __init__(self):
        self.sockets = []                  # add clients here, along with the listener
        self.authsockets = set()           # add authenticated clients here
        self.buffers = defaultdict(bytes)  # per-socket data received but not yet a full line

    def loop(self):
        # writes are left out of this sketch, so only ask select() about readability
        r, w, x = select.select(self.sockets, [], [])
        for sock in r:
            # (accepting on the listener and handling disconnects omitted)
            self.buffers[sock] += sock.recv(4096)
            lines = self.buffers[sock].split(b'\n')
            # the last piece is an unterminated partial line (b'' if the data
            # ended with a newline); keep it until more data arrives
            self.buffers[sock] = lines.pop()
            for line in lines:
                self.processCommand(sock, line)
        # etc.

    def processCommand(self, sock, command):
        if self.isAuthCommand(command):
            if self.isValidAuthCommand(command):
                self.authsockets.add(sock)
            return
        if sock not in self.authsockets:
            return  # ignore commands before auth
        self.doNormalThing(command)
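The buffering step in the read loop is what copes with TCP's lack of message boundaries. Pulled out into a standalone helper (split_lines() is a hypothetical name, not part of the answer) and fed recv() results that split one line and merge others, the idea behaves like this in Python 3:

```python
def split_lines(buffer, chunk):
    """Append a newly received chunk to buffer.

    Returns (complete_lines, new_buffer); new_buffer is the unterminated tail,
    which is b"" whenever the data ended exactly on a newline.
    """
    lines = (buffer + chunk).split(b"\n")
    return lines[:-1], lines[-1]


# recv() results that split one line in two and merge others together.
buf = b""
for chunk in [b"auth alice sec", b"ret\nhello\nhow are", b" you?\n"]:
    lines, buf = split_lines(buf, chunk)
    print(lines)
# prints:
# []
# [b'auth alice secret', b'hello']
# [b'how are you?']
```

Taking the last element of split() as the new buffer avoids special-casing whether the data ended with a newline, and it never indexes into an empty buffer.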

I've stripped out all of the irrelevant stuff: handling accepts, disconnects, errors, writes, etc. But your writes have a similar problem to your reads. First, you're assuming that sockets are always writable, which is not true. You need to queue up a write buffer for each socket, and write only when select tells you it's OK. Again, this may seem to work on localhost, but it will fall apart on the internet. Second, a single send() may not write the entire buffer, so you need to check how many bytes were actually written and keep buffer[bytecount:] around until next time.
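A minimal Python 3 sketch of both points, assuming non-blocking sockets: data is only ever appended to a per-socket outgoing buffer, and a flush step writes only to sockets that select() reports as writable, keeping buffer[sent:] for the next round. The names outbuf, queue_send() and flush_writable() are hypothetical, and socket.socketpair() stands in for a real client connection.

```python
import select
import socket
from collections import defaultdict

# Per-socket bytes that have been queued but not yet accepted by the kernel.
outbuf = defaultdict(bytes)


def queue_send(sock, data):
    """Never write directly; just append to that socket's outgoing buffer."""
    outbuf[sock] += data


def flush_writable(timeout=0.0):
    """Write to each socket select() reports writable, keeping whatever wasn't sent."""
    pending = [s for s, buf in outbuf.items() if buf]
    if not pending:
        return
    _, writable, _ = select.select([], pending, [], timeout)
    for sock in writable:
        try:
            sent = sock.send(outbuf[sock])  # may accept fewer bytes than offered
        except BlockingIOError:
            sent = 0
        outbuf[sock] = outbuf[sock][sent:]  # keep the unsent tail for next time


# Demo: queue far more data than fits in the kernel's socket buffer at once.
a, b = socket.socketpair()
a.setblocking(False)
b.setblocking(False)
payload = b"x" * 1000000
queue_send(a, payload)

received = bytearray()
while outbuf[a] or len(received) < len(payload):
    flush_writable(timeout=0.1)
    try:
        received += b.recv(65536)  # drain the other end so a becomes writable again
    except BlockingIOError:
        pass  # nothing to read right now
print(len(received) == len(payload))
a.close()
b.close()
```

For a blocking socket used outside a select loop, socket.sendall() performs this retry loop internally; inside a non-blocking select loop, the explicit per-socket buffer is what keeps one slow client from stalling everyone else.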
