diff --git a/network/udp/listen_unix.go b/network/udp/listen_unix.go index d32e7e5..7e65331 100644 --- a/network/udp/listen_unix.go +++ b/network/udp/listen_unix.go @@ -4,64 +4,64 @@ package udp import ( + "context" "fmt" "net" + "syscall" "golang.org/x/sys/unix" ) +// listenUDP crée workerCount sockets UDP distinctes avec SO_REUSEPORT activé. +// Chaque socket écoute sur bindAddr et sera utilisée par un worker. func listenUDP(bindAddr string, workerCount int) ([]*net.UDPConn, error) { - addr, err := net.ResolveUDPAddr("udp", bindAddr) - if err != nil { - return nil, fmt.Errorf("cannot resolve address: %w", err) + if workerCount <= 0 { + workerCount = 1 } - conns := make([]*net.UDPConn, 0, workerCount) + var conns []*net.UDPConn + + closeAll := func() { + for _, c := range conns { + _ = c.Close() + } + } + + lc := net.ListenConfig{ + Control: func(network, address string, c syscall.RawConn) error { + var sockErr error + err := c.Control(func(fd uintptr) { + if err := unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEADDR, 1); err != nil { + sockErr = fmt.Errorf("set SO_REUSEADDR: %w", err) + return + } + if err := unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil { + sockErr = fmt.Errorf("set SO_REUSEPORT: %w", err) + return + } + }) + if err != nil { + return err + } + return sockErr + }, + } for i := 0; i < workerCount; i++ { - conn, err := net.ListenUDP("udp", addr) + pc, err := lc.ListenPacket(context.Background(), "udp", bindAddr) if err != nil { - // En cas d’erreur, on ferme ce qu’on a ouvert. - for _, c := range conns { - _ = c.Close() - } - return nil, fmt.Errorf("cannot listen on address (worker %d): %w", i, err) + closeAll() + return nil, fmt.Errorf("worker %d: listen failed on %q: %w", i, bindAddr, err) } - rawConn, err := conn.SyscallConn() - if err != nil { - _ = conn.Close() - for _, c := range conns { - _ = c.Close() - } - return nil, fmt.Errorf("cannot get raw connection (worker %d): %w", i, err) + udpConn, ok := pc.(*net.UDPConn) + if !ok { + pc.Close() + closeAll() + return nil, fmt.Errorf("worker %d: PacketConn is not *net.UDPConn", i) } - var sockErr error - err = rawConn.Control(func(fd uintptr) { - sockErr = unix.SetsockoptInt( - int(fd), - unix.SOL_SOCKET, - unix.SO_REUSEPORT, - 1, - ) - }) - if err != nil { - _ = conn.Close() - for _, c := range conns { - _ = c.Close() - } - return nil, fmt.Errorf("control error (worker %d): %w", i, err) - } - if sockErr != nil { - _ = conn.Close() - for _, c := range conns { - _ = c.Close() - } - return nil, fmt.Errorf("cannot set SO_REUSEPORT (worker %d): %w", i, sockErr) - } - - conns = append(conns, conn) + conns = append(conns, udpConn) } return conns, nil diff --git a/network/udp/listen_windows.go b/network/udp/listen_windows.go index 1b25cea..dc83679 100644 --- a/network/udp/listen_windows.go +++ b/network/udp/listen_windows.go @@ -8,17 +8,18 @@ import ( "net" ) +// listenUDP sur Windows crée une seule socket UDP qui sera partagée par tous +// les workers (tous les workerID utiliseront conns[0]). func listenUDP(bindAddr string, workerCount int) ([]*net.UDPConn, error) { addr, err := net.ResolveUDPAddr("udp", bindAddr) if err != nil { - return nil, fmt.Errorf("cannot resolve address: %w", err) + return nil, fmt.Errorf("cannot resolve address %q: %w", bindAddr, err) } conn, err := net.ListenUDP("udp", addr) if err != nil { - return nil, fmt.Errorf("cannot listen on address: %w", err) + return nil, fmt.Errorf("cannot listen on address %q: %w", bindAddr, err) } - // Un seul conn partagé par tous les workers. return []*net.UDPConn{conn}, nil } diff --git a/network/udp/routing_table.go b/network/udp/routing_table.go index d5ef1d3..0445d3d 100644 --- a/network/udp/routing_table.go +++ b/network/udp/routing_table.go @@ -7,49 +7,61 @@ import ( mapset "github.com/deckarep/golang-set/v2" ) +// ... existing code ... + type RoutingTable struct { mu sync.RWMutex routes map[string]mapset.Set[*net.UDPAddr] } +// NewRoutingTable crée une table vide. func NewRoutingTable() *RoutingTable { return &RoutingTable{ routes: make(map[string]mapset.Set[*net.UDPAddr]), } } -func (rt *RoutingTable) AddClient(channelID string, addr *net.UDPAddr) { +// Add enregistre un client dans un channel. +func (rt *RoutingTable) Add(channel string, addr *net.UDPAddr) { rt.mu.Lock() defer rt.mu.Unlock() - if rt.routes[channelID] == nil { - rt.routes[channelID] = mapset.NewSet[*net.UDPAddr]() + set, ok := rt.routes[channel] + if !ok { + set = mapset.NewSet[*net.UDPAddr]() + rt.routes[channel] = set } - rt.routes[channelID].Add(addr) + set.Add(addr) } -func (rt *RoutingTable) RemoveClient(channelID string, addr *net.UDPAddr) { +// Remove supprime un client d'un channel. +func (rt *RoutingTable) Remove(channel string, addr *net.UDPAddr) { rt.mu.Lock() defer rt.mu.Unlock() - if clients, exists := rt.routes[channelID]; exists { - clients.Remove(addr) - if clients.Cardinality() == 0 { - delete(rt.routes, channelID) - } + set, ok := rt.routes[channel] + if !ok { + return + } + set.Remove(addr) + if set.Cardinality() == 0 { + delete(rt.routes, channel) } } -// GetClients returns the clients connected to the given channelID -// don't modify the returned set! -func (rt *RoutingTable) GetClients(channelID string) mapset.Set[*net.UDPAddr] { +// GetAddrs renvoie une copie de la liste des clients d'un channel. +func (rt *RoutingTable) GetAddrs(channel string) []*net.UDPAddr { rt.mu.RLock() defer rt.mu.RUnlock() - clients, exists := rt.routes[channelID] - if !exists { + set, ok := rt.routes[channel] + if !ok { return nil } - return clients + addrs := make([]*net.UDPAddr, 0, set.Cardinality()) + for addr := range set.Iter() { + addrs = append(addrs, addr) + } + return addrs } diff --git a/network/udp/server.go b/network/udp/server.go index eacd036..9be38d0 100644 --- a/network/udp/server.go +++ b/network/udp/server.go @@ -22,13 +22,22 @@ func NewServer(bindAddr string) *Server { ctx, cancel := context.WithCancel(context.Background()) return &Server{ - bindAddr: bindAddr, - ctx: ctx, - cancel: cancel, + bindAddr: bindAddr, + routingTable: NewRoutingTable(), + ctx: ctx, + cancel: cancel, } } +func (s *Server) Router() *RoutingTable { + return s.routingTable +} + func (s *Server) Run() error { + if s.bindAddr == "" { + return fmt.Errorf("bind address is empty") + } + workerCount := runtime.NumCPU() conns, err := listenUDP(s.bindAddr, workerCount) @@ -41,78 +50,102 @@ func (s *Server) Run() error { s.conns = conns for _, conn := range s.conns { - conn.SetReadBuffer(8 * 1024 * 1024) - conn.SetWriteBuffer(8 * 1024 * 1024) + _ = conn.SetReadBuffer(8 * 1024 * 1024) + _ = conn.SetWriteBuffer(8 * 1024 * 1024) } - fmt.Println("Listening on", s.bindAddr) + fmt.Println("[udp] listening on", s.bindAddr, "with", len(s.conns), "worker(s)") - for i, conn := range s.conns { + for workerID, conn := range s.conns { s.wg.Add(1) - go s.workerLoop(i, conn) + go s.workerLoop(workerID, conn) } return nil } -func (s *Server) sendTo(data []byte, addr *net.UDPAddr) error { - if len(s.conns) == 0 || s.conns[0] == nil { - return fmt.Errorf("server not started") +func (s *Server) Stop() { + s.cancel() + for _, c := range s.conns { + _ = c.Close() } - - // On utilise la première conn pour l’envoi (c’est suffisant pour UDP). - _, err := s.conns[0].WriteToUDP(data, addr) - return err + s.wg.Wait() } -func (s *Server) workerLoop(id int, conn *net.UDPConn) { +func (s *Server) workerLoop(workerID int, conn *net.UDPConn) { defer s.wg.Done() - buffer := make([]byte, 1500) - fmt.Println("Worker", id, "started") + buf := make([]byte, 1500) + fmt.Println("[udp] worker", workerID, "started") for { select { case <-s.ctx.Done(): - fmt.Println("Worker", id, "stopped") + fmt.Println("[udp] worker", workerID, "stopped") return default: - size, addr, err := conn.ReadFromUDP(buffer) + n, addr, err := conn.ReadFromUDP(buf) if err != nil { if s.ctx.Err() != nil { return } - if opErr, ok := err.(*net.OpError); ok && opErr.Temporary() { + if ne, ok := err.(net.Error); ok && ne.Timeout() { continue } - fmt.Printf("Error reading from UDP (worker %d): %v\n", id, err) + fmt.Printf("[udp] worker %d: read error: %v\n", workerID, err) continue } - s.handlePacket(buffer[:size], addr) + // Le worker qui lit est aussi celui qui traite et qui écrit. + s.handlePacket(conn, buf[:n], addr) } } } -func (s *Server) handlePacket(data []byte, addr *net.UDPAddr) { +// handlePacket reçoit maintenant directement la conn du worker. +// Aucun hop supplémentaire : le paquet reste dans la goroutine/conn du worker. +func (s *Server) handlePacket(conn *net.UDPConn, data []byte, addr *net.UDPAddr) { if len(data) == 0 { return } pt := PacketType(data[0]) + switch pt { case PacketTypePing: - _ = s.sendTo([]byte{byte(PacketTypePing)}, addr) - return + // exemple simple : echo du ping + _, _ = conn.WriteToUDP([]byte{byte(PacketTypePing)}, addr) + case PacketTypeConnect: - return + if len(data) < 2 { + return + } + channelID := string(data[1:]) + s.routingTable.Add(channelID, addr) + case PacketTypeDisconnect: - return + if len(data) < 2 { + return + } + channelID := string(data[1:]) + s.routingTable.Remove(channelID, addr) + case PacketTypeVoiceData: - // todo : déterminer le format du packet - // channelID := string(data[1:5]) - return + if len(data) < 2 { + return + } + channelID := string(data[1:]) // à adapter selon ton vrai format + recipients := s.routingTable.GetAddrs(channelID) + + for _, dst := range recipients { + // optionnel: ne pas renvoyer à la source + if dst.IP.Equal(addr.IP) && dst.Port == addr.Port { + continue + } + _, _ = conn.WriteToUDP(data, dst) + } + default: - return + // type inconnu -> ignore } }