From c07dd60cdb8eb3fc87b63ed0938979e4e4fb6278 Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Sat, 30 Oct 2021 02:39:56 +0200 Subject: namedpipe: rename from winpipe to keep in sync with CL299009 Signed-off-by: Jason A. Donenfeld --- ipc/namedpipe/file.go | 286 +++++++++++++++++ ipc/namedpipe/namedpipe.go | 486 +++++++++++++++++++++++++++++ ipc/namedpipe/namedpipe_test.go | 675 ++++++++++++++++++++++++++++++++++++++++ ipc/uapi_windows.go | 8 +- ipc/winpipe/file.go | 286 ----------------- ipc/winpipe/winpipe.go | 474 ---------------------------- ipc/winpipe/winpipe_test.go | 660 --------------------------------------- tun/wintun/dll_windows.go | 128 -------- tun/wintun/session_windows.go | 90 ------ tun/wintun/wintun_windows.go | 150 --------- 10 files changed, 1450 insertions(+), 1793 deletions(-) create mode 100644 ipc/namedpipe/file.go create mode 100644 ipc/namedpipe/namedpipe.go create mode 100644 ipc/namedpipe/namedpipe_test.go delete mode 100644 ipc/winpipe/file.go delete mode 100644 ipc/winpipe/winpipe.go delete mode 100644 ipc/winpipe/winpipe_test.go delete mode 100644 tun/wintun/dll_windows.go delete mode 100644 tun/wintun/session_windows.go delete mode 100644 tun/wintun/wintun_windows.go diff --git a/ipc/namedpipe/file.go b/ipc/namedpipe/file.go new file mode 100644 index 0000000..9c2481d --- /dev/null +++ b/ipc/namedpipe/file.go @@ -0,0 +1,286 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Copyright 2015 Microsoft +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build windows +// +build windows + +package namedpipe + +import ( + "io" + "os" + "runtime" + "sync" + "sync/atomic" + "time" + "unsafe" + + "golang.org/x/sys/windows" +) + +type timeoutChan chan struct{} + +var ioInitOnce sync.Once +var ioCompletionPort windows.Handle + +// ioResult contains the result of an asynchronous IO operation +type ioResult struct { + bytes uint32 + err error +} + +// ioOperation represents an outstanding asynchronous Win32 IO +type ioOperation struct { + o windows.Overlapped + ch chan ioResult +} + +func initIo() { + h, err := windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0) + if err != nil { + panic(err) + } + ioCompletionPort = h + go ioCompletionProcessor(h) +} + +// file implements Reader, Writer, and Closer on a Win32 handle without blocking in a syscall. +// It takes ownership of this handle and will close it if it is garbage collected. +type file struct { + handle windows.Handle + wg sync.WaitGroup + wgLock sync.RWMutex + closing uint32 // used as atomic boolean + socket bool + readDeadline deadlineHandler + writeDeadline deadlineHandler +} + +type deadlineHandler struct { + setLock sync.Mutex + channel timeoutChan + channelLock sync.RWMutex + timer *time.Timer + timedout uint32 // used as atomic boolean +} + +// makeFile makes a new file from an existing file handle +func makeFile(h windows.Handle) (*file, error) { + f := &file{handle: h} + ioInitOnce.Do(initIo) + _, err := windows.CreateIoCompletionPort(h, ioCompletionPort, 0, 0) + if err != nil { + return nil, err + } + err = windows.SetFileCompletionNotificationModes(h, windows.FILE_SKIP_COMPLETION_PORT_ON_SUCCESS|windows.FILE_SKIP_SET_EVENT_ON_HANDLE) + if err != nil { + return nil, err + } + f.readDeadline.channel = make(timeoutChan) + f.writeDeadline.channel = make(timeoutChan) + return f, nil +} + +// closeHandle closes the resources associated with a Win32 handle +func (f *file) closeHandle() { + f.wgLock.Lock() + // Atomically set that we are closing, releasing the resources only once. + if atomic.SwapUint32(&f.closing, 1) == 0 { + f.wgLock.Unlock() + // cancel all IO and wait for it to complete + windows.CancelIoEx(f.handle, nil) + f.wg.Wait() + // at this point, no new IO can start + windows.Close(f.handle) + f.handle = 0 + } else { + f.wgLock.Unlock() + } +} + +// Close closes a file. +func (f *file) Close() error { + f.closeHandle() + return nil +} + +// prepareIo prepares for a new IO operation. +// The caller must call f.wg.Done() when the IO is finished, prior to Close() returning. +func (f *file) prepareIo() (*ioOperation, error) { + f.wgLock.RLock() + if atomic.LoadUint32(&f.closing) == 1 { + f.wgLock.RUnlock() + return nil, os.ErrClosed + } + f.wg.Add(1) + f.wgLock.RUnlock() + c := &ioOperation{} + c.ch = make(chan ioResult) + return c, nil +} + +// ioCompletionProcessor processes completed async IOs forever +func ioCompletionProcessor(h windows.Handle) { + for { + var bytes uint32 + var key uintptr + var op *ioOperation + err := windows.GetQueuedCompletionStatus(h, &bytes, &key, (**windows.Overlapped)(unsafe.Pointer(&op)), windows.INFINITE) + if op == nil { + panic(err) + } + op.ch <- ioResult{bytes, err} + } +} + +// asyncIo processes the return value from ReadFile or WriteFile, blocking until +// the operation has actually completed. +func (f *file) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err error) (int, error) { + if err != windows.ERROR_IO_PENDING { + return int(bytes), err + } + + if atomic.LoadUint32(&f.closing) == 1 { + windows.CancelIoEx(f.handle, &c.o) + } + + var timeout timeoutChan + if d != nil { + d.channelLock.Lock() + timeout = d.channel + d.channelLock.Unlock() + } + + var r ioResult + select { + case r = <-c.ch: + err = r.err + if err == windows.ERROR_OPERATION_ABORTED { + if atomic.LoadUint32(&f.closing) == 1 { + err = os.ErrClosed + } + } else if err != nil && f.socket { + // err is from Win32. Query the overlapped structure to get the winsock error. + var bytes, flags uint32 + err = windows.WSAGetOverlappedResult(f.handle, &c.o, &bytes, false, &flags) + } + case <-timeout: + windows.CancelIoEx(f.handle, &c.o) + r = <-c.ch + err = r.err + if err == windows.ERROR_OPERATION_ABORTED { + err = os.ErrDeadlineExceeded + } + } + + // runtime.KeepAlive is needed, as c is passed via native + // code to ioCompletionProcessor, c must remain alive + // until the channel read is complete. + runtime.KeepAlive(c) + return int(r.bytes), err +} + +// Read reads from a file handle. +func (f *file) Read(b []byte) (int, error) { + c, err := f.prepareIo() + if err != nil { + return 0, err + } + defer f.wg.Done() + + if atomic.LoadUint32(&f.readDeadline.timedout) == 1 { + return 0, os.ErrDeadlineExceeded + } + + var bytes uint32 + err = windows.ReadFile(f.handle, b, &bytes, &c.o) + n, err := f.asyncIo(c, &f.readDeadline, bytes, err) + runtime.KeepAlive(b) + + // Handle EOF conditions. + if err == nil && n == 0 && len(b) != 0 { + return 0, io.EOF + } else if err == windows.ERROR_BROKEN_PIPE { + return 0, io.EOF + } else { + return n, err + } +} + +// Write writes to a file handle. +func (f *file) Write(b []byte) (int, error) { + c, err := f.prepareIo() + if err != nil { + return 0, err + } + defer f.wg.Done() + + if atomic.LoadUint32(&f.writeDeadline.timedout) == 1 { + return 0, os.ErrDeadlineExceeded + } + + var bytes uint32 + err = windows.WriteFile(f.handle, b, &bytes, &c.o) + n, err := f.asyncIo(c, &f.writeDeadline, bytes, err) + runtime.KeepAlive(b) + return n, err +} + +func (f *file) SetReadDeadline(deadline time.Time) error { + return f.readDeadline.set(deadline) +} + +func (f *file) SetWriteDeadline(deadline time.Time) error { + return f.writeDeadline.set(deadline) +} + +func (f *file) Flush() error { + return windows.FlushFileBuffers(f.handle) +} + +func (f *file) Fd() uintptr { + return uintptr(f.handle) +} + +func (d *deadlineHandler) set(deadline time.Time) error { + d.setLock.Lock() + defer d.setLock.Unlock() + + if d.timer != nil { + if !d.timer.Stop() { + <-d.channel + } + d.timer = nil + } + atomic.StoreUint32(&d.timedout, 0) + + select { + case <-d.channel: + d.channelLock.Lock() + d.channel = make(chan struct{}) + d.channelLock.Unlock() + default: + } + + if deadline.IsZero() { + return nil + } + + timeoutIO := func() { + atomic.StoreUint32(&d.timedout, 1) + close(d.channel) + } + + now := time.Now() + duration := deadline.Sub(now) + if deadline.After(now) { + // Deadline is in the future, set a timer to wait + d.timer = time.AfterFunc(duration, timeoutIO) + } else { + // Deadline is in the past. Cancel all pending IO now. + timeoutIO() + } + return nil +} diff --git a/ipc/namedpipe/namedpipe.go b/ipc/namedpipe/namedpipe.go new file mode 100644 index 0000000..6db5ea3 --- /dev/null +++ b/ipc/namedpipe/namedpipe.go @@ -0,0 +1,486 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Copyright 2015 Microsoft +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build windows +// +build windows + +// Package namedpipe implements a net.Conn and net.Listener around Windows named pipes. +package namedpipe + +import ( + "context" + "io" + "net" + "os" + "runtime" + "sync/atomic" + "time" + "unsafe" + + "golang.org/x/sys/windows" +) + +type pipe struct { + *file + path string +} + +type messageBytePipe struct { + pipe + writeClosed int32 + readEOF bool +} + +type pipeAddress string + +func (f *pipe) LocalAddr() net.Addr { + return pipeAddress(f.path) +} + +func (f *pipe) RemoteAddr() net.Addr { + return pipeAddress(f.path) +} + +func (f *pipe) SetDeadline(t time.Time) error { + f.SetReadDeadline(t) + f.SetWriteDeadline(t) + return nil +} + +// CloseWrite closes the write side of a message pipe in byte mode. +func (f *messageBytePipe) CloseWrite() error { + if !atomic.CompareAndSwapInt32(&f.writeClosed, 0, 1) { + return io.ErrClosedPipe + } + err := f.file.Flush() + if err != nil { + atomic.StoreInt32(&f.writeClosed, 0) + return err + } + _, err = f.file.Write(nil) + if err != nil { + atomic.StoreInt32(&f.writeClosed, 0) + return err + } + return nil +} + +// Write writes bytes to a message pipe in byte mode. Zero-byte writes are ignored, since +// they are used to implement CloseWrite. +func (f *messageBytePipe) Write(b []byte) (int, error) { + if atomic.LoadInt32(&f.writeClosed) != 0 { + return 0, io.ErrClosedPipe + } + if len(b) == 0 { + return 0, nil + } + return f.file.Write(b) +} + +// Read reads bytes from a message pipe in byte mode. A read of a zero-byte message on a message +// mode pipe will return io.EOF, as will all subsequent reads. +func (f *messageBytePipe) Read(b []byte) (int, error) { + if f.readEOF { + return 0, io.EOF + } + n, err := f.file.Read(b) + if err == io.EOF { + // If this was the result of a zero-byte read, then + // it is possible that the read was due to a zero-size + // message. Since we are simulating CloseWrite with a + // zero-byte message, ensure that all future Read calls + // also return EOF. + f.readEOF = true + } else if err == windows.ERROR_MORE_DATA { + // ERROR_MORE_DATA indicates that the pipe's read mode is message mode + // and the message still has more bytes. Treat this as a success, since + // this package presents all named pipes as byte streams. + err = nil + } + return n, err +} + +func (f *pipe) Handle() windows.Handle { + return f.handle +} + +func (s pipeAddress) Network() string { + return "pipe" +} + +func (s pipeAddress) String() string { + return string(s) +} + +// tryDialPipe attempts to dial the specified pipe until cancellation or timeout. +func tryDialPipe(ctx context.Context, path *string) (windows.Handle, error) { + for { + select { + case <-ctx.Done(): + return 0, ctx.Err() + default: + path16, err := windows.UTF16PtrFromString(*path) + if err != nil { + return 0, err + } + h, err := windows.CreateFile(path16, windows.GENERIC_READ|windows.GENERIC_WRITE, 0, nil, windows.OPEN_EXISTING, windows.FILE_FLAG_OVERLAPPED|windows.SECURITY_SQOS_PRESENT|windows.SECURITY_ANONYMOUS, 0) + if err == nil { + return h, nil + } + if err != windows.ERROR_PIPE_BUSY { + return h, &os.PathError{Err: err, Op: "open", Path: *path} + } + // Wait 10 msec and try again. This is a rather simplistic + // view, as we always try each 10 milliseconds. + time.Sleep(10 * time.Millisecond) + } + } +} + +// DialConfig exposes various options for use in Dial and DialContext. +type DialConfig struct { + ExpectedOwner *windows.SID // If non-nil, the pipe is verified to be owned by this SID. +} + +// DialTimeout connects to the specified named pipe by path, timing out if the +// connection takes longer than the specified duration. If timeout is zero, then +// we use a default timeout of 2 seconds. +func (config *DialConfig) DialTimeout(path string, timeout time.Duration) (net.Conn, error) { + if timeout == 0 { + timeout = time.Second * 2 + } + absTimeout := time.Now().Add(timeout) + ctx, _ := context.WithDeadline(context.Background(), absTimeout) + conn, err := config.DialContext(ctx, path) + if err == context.DeadlineExceeded { + return nil, os.ErrDeadlineExceeded + } + return conn, err +} + +// DialContext attempts to connect to the specified named pipe by path. +func (config *DialConfig) DialContext(ctx context.Context, path string) (net.Conn, error) { + var err error + var h windows.Handle + h, err = tryDialPipe(ctx, &path) + if err != nil { + return nil, err + } + + if config.ExpectedOwner != nil { + sd, err := windows.GetSecurityInfo(h, windows.SE_FILE_OBJECT, windows.OWNER_SECURITY_INFORMATION) + if err != nil { + windows.Close(h) + return nil, err + } + realOwner, _, err := sd.Owner() + if err != nil { + windows.Close(h) + return nil, err + } + if !realOwner.Equals(config.ExpectedOwner) { + windows.Close(h) + return nil, windows.ERROR_ACCESS_DENIED + } + } + + var flags uint32 + err = windows.GetNamedPipeInfo(h, &flags, nil, nil, nil) + if err != nil { + windows.Close(h) + return nil, err + } + + f, err := makeFile(h) + if err != nil { + windows.Close(h) + return nil, err + } + + // If the pipe is in message mode, return a message byte pipe, which + // supports CloseWrite. + if flags&windows.PIPE_TYPE_MESSAGE != 0 { + return &messageBytePipe{ + pipe: pipe{file: f, path: path}, + }, nil + } + return &pipe{file: f, path: path}, nil +} + +var defaultDialer DialConfig + +// DialTimeout calls DialConfig.DialTimeout using an empty configuration. +func DialTimeout(path string, timeout time.Duration) (net.Conn, error) { + return defaultDialer.DialTimeout(path, timeout) +} + +// DialContext calls DialConfig.DialContext using an empty configuration. +func DialContext(ctx context.Context, path string) (net.Conn, error) { + return defaultDialer.DialContext(ctx, path) +} + +type acceptResponse struct { + f *file + err error +} + +type pipeListener struct { + firstHandle windows.Handle + path string + config ListenConfig + acceptCh chan chan acceptResponse + closeCh chan int + doneCh chan int +} + +func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *ListenConfig, isFirstPipe bool) (windows.Handle, error) { + path16, err := windows.UTF16PtrFromString(path) + if err != nil { + return 0, &os.PathError{Op: "open", Path: path, Err: err} + } + + var oa windows.OBJECT_ATTRIBUTES + oa.Length = uint32(unsafe.Sizeof(oa)) + + var ntPath windows.NTUnicodeString + if err := windows.RtlDosPathNameToNtPathName(path16, &ntPath, nil, nil); err != nil { + if ntstatus, ok := err.(windows.NTStatus); ok { + err = ntstatus.Errno() + } + return 0, &os.PathError{Op: "open", Path: path, Err: err} + } + defer windows.LocalFree(windows.Handle(unsafe.Pointer(ntPath.Buffer))) + oa.ObjectName = &ntPath + + // The security descriptor is only needed for the first pipe. + if isFirstPipe { + if sd != nil { + oa.SecurityDescriptor = sd + } else { + // Construct the default named pipe security descriptor. + var acl *windows.ACL + if err := windows.RtlDefaultNpAcl(&acl); err != nil { + return 0, err + } + defer windows.LocalFree(windows.Handle(unsafe.Pointer(acl))) + sd, err = windows.NewSecurityDescriptor() + if err != nil { + return 0, err + } + if err = sd.SetDACL(acl, true, false); err != nil { + return 0, err + } + oa.SecurityDescriptor = sd + } + } + + typ := uint32(windows.FILE_PIPE_REJECT_REMOTE_CLIENTS) + if c.MessageMode { + typ |= windows.FILE_PIPE_MESSAGE_TYPE + } + + disposition := uint32(windows.FILE_OPEN) + access := uint32(windows.GENERIC_READ | windows.GENERIC_WRITE | windows.SYNCHRONIZE) + if isFirstPipe { + disposition = windows.FILE_CREATE + // By not asking for read or write access, the named pipe file system + // will put this pipe into an initially disconnected state, blocking + // client connections until the next call with isFirstPipe == false. + access = windows.SYNCHRONIZE + } + + timeout := int64(-50 * 10000) // 50ms + + var ( + h windows.Handle + iosb windows.IO_STATUS_BLOCK + ) + err = windows.NtCreateNamedPipeFile(&h, access, &oa, &iosb, windows.FILE_SHARE_READ|windows.FILE_SHARE_WRITE, disposition, 0, typ, 0, 0, 0xffffffff, uint32(c.InputBufferSize), uint32(c.OutputBufferSize), &timeout) + if err != nil { + if ntstatus, ok := err.(windows.NTStatus); ok { + err = ntstatus.Errno() + } + return 0, &os.PathError{Op: "open", Path: path, Err: err} + } + + runtime.KeepAlive(ntPath) + return h, nil +} + +func (l *pipeListener) makeServerPipe() (*file, error) { + h, err := makeServerPipeHandle(l.path, nil, &l.config, false) + if err != nil { + return nil, err + } + f, err := makeFile(h) + if err != nil { + windows.Close(h) + return nil, err + } + return f, nil +} + +func (l *pipeListener) makeConnectedServerPipe() (*file, error) { + p, err := l.makeServerPipe() + if err != nil { + return nil, err + } + + // Wait for the client to connect. + ch := make(chan error) + go func(p *file) { + ch <- connectPipe(p) + }(p) + + select { + case err = <-ch: + if err != nil { + p.Close() + p = nil + } + case <-l.closeCh: + // Abort the connect request by closing the handle. + p.Close() + p = nil + err = <-ch + if err == nil || err == os.ErrClosed { + err = net.ErrClosed + } + } + return p, err +} + +func (l *pipeListener) listenerRoutine() { + closed := false + for !closed { + select { + case <-l.closeCh: + closed = true + case responseCh := <-l.acceptCh: + var ( + p *file + err error + ) + for { + p, err = l.makeConnectedServerPipe() + // If the connection was immediately closed by the client, try + // again. + if err != windows.ERROR_NO_DATA { + break + } + } + responseCh <- acceptResponse{p, err} + closed = err == net.ErrClosed + } + } + windows.Close(l.firstHandle) + l.firstHandle = 0 + // Notify Close and Accept callers that the handle has been closed. + close(l.doneCh) +} + +// ListenConfig contains configuration for the pipe listener. +type ListenConfig struct { + // SecurityDescriptor contains a Windows security descriptor. If nil, the default from RtlDefaultNpAcl is used. + SecurityDescriptor *windows.SECURITY_DESCRIPTOR + + // MessageMode determines whether the pipe is in byte or message mode. In either + // case the pipe is read in byte mode by default. The only practical difference in + // this implementation is that CloseWrite is only supported for message mode pipes; + // CloseWrite is implemented as a zero-byte write, but zero-byte writes are only + // transferred to the reader (and returned as io.EOF in this implementation) + // when the pipe is in message mode. + MessageMode bool + + // InputBufferSize specifies the initial size of the input buffer, in bytes, which the OS will grow as needed. + InputBufferSize int32 + + // OutputBufferSize specifies the initial size of the output buffer, in bytes, which the OS will grow as needed. + OutputBufferSize int32 +} + +// Listen creates a listener on a Windows named pipe path,such as \\.\pipe\mypipe. +// The pipe must not already exist. +func (c *ListenConfig) Listen(path string) (net.Listener, error) { + h, err := makeServerPipeHandle(path, c.SecurityDescriptor, c, true) + if err != nil { + return nil, err + } + l := &pipeListener{ + firstHandle: h, + path: path, + config: *c, + acceptCh: make(chan chan acceptResponse), + closeCh: make(chan int), + doneCh: make(chan int), + } + // The first connection is swallowed on Windows 7 & 8, so synthesize it. + if maj, min, _ := windows.RtlGetNtVersionNumbers(); maj < 6 || (maj == 6 && min < 4) { + path16, err := windows.UTF16PtrFromString(path) + if err == nil { + h, err = windows.CreateFile(path16, 0, 0, nil, windows.OPEN_EXISTING, windows.SECURITY_SQOS_PRESENT|windows.SECURITY_ANONYMOUS, 0) + if err == nil { + windows.CloseHandle(h) + } + } + } + go l.listenerRoutine() + return l, nil +} + +var defaultListener ListenConfig + +// Listen calls ListenConfig.Listen using an empty configuration. +func Listen(path string) (net.Listener, error) { + return defaultListener.Listen(path) +} + +func connectPipe(p *file) error { + c, err := p.prepareIo() + if err != nil { + return err + } + defer p.wg.Done() + + err = windows.ConnectNamedPipe(p.handle, &c.o) + _, err = p.asyncIo(c, nil, 0, err) + if err != nil && err != windows.ERROR_PIPE_CONNECTED { + return err + } + return nil +} + +func (l *pipeListener) Accept() (net.Conn, error) { + ch := make(chan acceptResponse) + select { + case l.acceptCh <- ch: + response := <-ch + err := response.err + if err != nil { + return nil, err + } + if l.config.MessageMode { + return &messageBytePipe{ + pipe: pipe{file: response.f, path: l.path}, + }, nil + } + return &pipe{file: response.f, path: l.path}, nil + case <-l.doneCh: + return nil, net.ErrClosed + } +} + +func (l *pipeListener) Close() error { + select { + case l.closeCh <- 1: + <-l.doneCh + case <-l.doneCh: + } + return nil +} + +func (l *pipeListener) Addr() net.Addr { + return pipeAddress(l.path) +} diff --git a/ipc/namedpipe/namedpipe_test.go b/ipc/namedpipe/namedpipe_test.go new file mode 100644 index 0000000..0573d0f --- /dev/null +++ b/ipc/namedpipe/namedpipe_test.go @@ -0,0 +1,675 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Copyright 2015 Microsoft +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build windows +// +build windows + +package namedpipe_test + +import ( + "bufio" + "bytes" + "context" + "errors" + "io" + "net" + "os" + "sync" + "syscall" + "testing" + "time" + + "golang.org/x/sys/windows" + "golang.zx2c4.com/wireguard/ipc/namedpipe" +) + +func randomPipePath() string { + guid, err := windows.GenerateGUID() + if err != nil { + panic(err) + } + return `\\.\PIPE\go-namedpipe-test-` + guid.String() +} + +func TestPingPong(t *testing.T) { + const ( + ping = 42 + pong = 24 + ) + pipePath := randomPipePath() + listener, err := namedpipe.Listen(pipePath) + if err != nil { + t.Fatalf("unable to listen on pipe: %v", err) + } + defer listener.Close() + go func() { + incoming, err := listener.Accept() + if err != nil { + t.Fatalf("unable to accept pipe connection: %v", err) + } + defer incoming.Close() + var data [1]byte + _, err = incoming.Read(data[:]) + if err != nil { + t.Fatalf("unable to read ping from pipe: %v", err) + } + if data[0] != ping { + t.Fatalf("expected ping, got %d", data[0]) + } + data[0] = pong + _, err = incoming.Write(data[:]) + if err != nil { + t.Fatalf("unable to write pong to pipe: %v", err) + } + }() + client, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) + if err != nil { + t.Fatalf("unable to dial pipe: %v", err) + } + defer client.Close() + client.SetDeadline(time.Now().Add(time.Second * 5)) + var data [1]byte + data[0] = ping + _, err = client.Write(data[:]) + if err != nil { + t.Fatalf("unable to write ping to pipe: %v", err) + } + _, err = client.Read(data[:]) + if err != nil { + t.Fatalf("unable to read pong from pipe: %v", err) + } + if data[0] != pong { + t.Fatalf("expected pong, got %d", data[0]) + } +} + +func TestDialUnknownFailsImmediately(t *testing.T) { + _, err := namedpipe.DialTimeout(randomPipePath(), time.Duration(0)) + if !errors.Is(err, syscall.ENOENT) { + t.Fatalf("expected ENOENT got %v", err) + } +} + +func TestDialListenerTimesOut(t *testing.T) { + pipePath := randomPipePath() + l, err := namedpipe.Listen(pipePath) + if err != nil { + t.Fatal(err) + } + defer l.Close() + pipe, err := namedpipe.DialTimeout(pipePath, 10*time.Millisecond) + if err == nil { + pipe.Close() + } + if err != os.ErrDeadlineExceeded { + t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err) + } +} + +func TestDialContextListenerTimesOut(t *testing.T) { + pipePath := randomPipePath() + l, err := namedpipe.Listen(pipePath) + if err != nil { + t.Fatal(err) + } + defer l.Close() + d := 10 * time.Millisecond + ctx, _ := context.WithTimeout(context.Background(), d) + pipe, err := namedpipe.DialContext(ctx, pipePath) + if err == nil { + pipe.Close() + } + if err != context.DeadlineExceeded { + t.Fatalf("expected context.DeadlineExceeded, got %v", err) + } +} + +func TestDialListenerGetsCancelled(t *testing.T) { + pipePath := randomPipePath() + ctx, cancel := context.WithCancel(context.Background()) + l, err := namedpipe.Listen(pipePath) + if err != nil { + t.Fatal(err) + } + defer l.Close() + ch := make(chan error) + go func(ctx context.Context, ch chan error) { + _, err := namedpipe.DialContext(ctx, pipePath) + ch <- err + }(ctx, ch) + time.Sleep(time.Millisecond * 30) + cancel() + err = <-ch + if err != context.Canceled { + t.Fatalf("expected context.Canceled, got %v", err) + } +} + +func TestDialAccessDeniedWithRestrictedSD(t *testing.T) { + if windows.NewLazySystemDLL("ntdll.dll").NewProc("wine_get_version").Find() == nil { + t.Skip("dacls on named pipes are broken on wine") + } + pipePath := randomPipePath() + sd, _ := windows.SecurityDescriptorFromString("D:") + l, err := (&namedpipe.ListenConfig{ + SecurityDescriptor: sd, + }).Listen(pipePath) + if err != nil { + t.Fatal(err) + } + defer l.Close() + pipe, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) + if err == nil { + pipe.Close() + } + if !errors.Is(err, windows.ERROR_ACCESS_DENIED) { + t.Fatalf("expected ERROR_ACCESS_DENIED, got %v", err) + } +} + +func getConnection(cfg *namedpipe.ListenConfig) (client net.Conn, server net.Conn, err error) { + pipePath := randomPipePath() + if cfg == nil { + cfg = &namedpipe.ListenConfig{} + } + l, err := cfg.Listen(pipePath) + if err != nil { + return + } + defer l.Close() + + type response struct { + c net.Conn + err error + } + ch := make(chan response) + go func() { + c, err := l.Accept() + ch <- response{c, err} + }() + + c, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) + if err != nil { + return + } + + r := <-ch + if err = r.err; err != nil { + c.Close() + return + } + + client = c + server = r.c + return +} + +func TestReadTimeout(t *testing.T) { + c, s, err := getConnection(nil) + if err != nil { + t.Fatal(err) + } + defer c.Close() + defer s.Close() + + c.SetReadDeadline(time.Now().Add(10 * time.Millisecond)) + + buf := make([]byte, 10) + _, err = c.Read(buf) + if err != os.ErrDeadlineExceeded { + t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err) + } +} + +func server(l net.Listener, ch chan int) { + c, err := l.Accept() + if err != nil { + panic(err) + } + rw := bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c)) + s, err := rw.ReadString('\n') + if err != nil { + panic(err) + } + _, err = rw.WriteString("got " + s) + if err != nil { + panic(err) + } + err = rw.Flush() + if err != nil { + panic(err) + } + c.Close() + ch <- 1 +} + +func TestFullListenDialReadWrite(t *testing.T) { + pipePath := randomPipePath() + l, err := namedpipe.Listen(pipePath) + if err != nil { + t.Fatal(err) + } + defer l.Close() + + ch := make(chan int) + go server(l, ch) + + c, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) + if err != nil { + t.Fatal(err) + } + defer c.Close() + + rw := bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c)) + _, err = rw.WriteString("hello world\n") + if err != nil { + t.Fatal(err) + } + err = rw.Flush() + if err != nil { + t.Fatal(err) + } + + s, err := rw.ReadString('\n') + if err != nil { + t.Fatal(err) + } + ms := "got hello world\n" + if s != ms { + t.Errorf("expected '%s', got '%s'", ms, s) + } + + <-ch +} + +func TestCloseAbortsListen(t *testing.T) { + pipePath := randomPipePath() + l, err := namedpipe.Listen(pipePath) + if err != nil { + t.Fatal(err) + } + + ch := make(chan error) + go func() { + _, err := l.Accept() + ch <- err + }() + + time.Sleep(30 * time.Millisecond) + l.Close() + + err = <-ch + if err != net.ErrClosed { + t.Fatalf("expected net.ErrClosed, got %v", err) + } +} + +func ensureEOFOnClose(t *testing.T, r io.Reader, w io.Closer) { + b := make([]byte, 10) + w.Close() + n, err := r.Read(b) + if n > 0 { + t.Errorf("unexpected byte count %d", n) + } + if err != io.EOF { + t.Errorf("expected EOF: %v", err) + } +} + +func TestCloseClientEOFServer(t *testing.T) { + c, s, err := getConnection(nil) + if err != nil { + t.Fatal(err) + } + defer c.Close() + defer s.Close() + ensureEOFOnClose(t, c, s) +} + +func TestCloseServerEOFClient(t *testing.T) { + c, s, err := getConnection(nil) + if err != nil { + t.Fatal(err) + } + defer c.Close() + defer s.Close() + ensureEOFOnClose(t, s, c) +} + +func TestCloseWriteEOF(t *testing.T) { + cfg := &namedpipe.ListenConfig{ + MessageMode: true, + } + c, s, err := getConnection(cfg) + if err != nil { + t.Fatal(err) + } + defer c.Close() + defer s.Close() + + type closeWriter interface { + CloseWrite() error + } + + err = c.(closeWriter).CloseWrite() + if err != nil { + t.Fatal(err) + } + + b := make([]byte, 10) + _, err = s.Read(b) + if err != io.EOF { + t.Fatal(err) + } +} + +func TestAcceptAfterCloseFails(t *testing.T) { + pipePath := randomPipePath() + l, err := namedpipe.Listen(pipePath) + if err != nil { + t.Fatal(err) + } + l.Close() + _, err = l.Accept() + if err != net.ErrClosed { + t.Fatalf("expected net.ErrClosed, got %v", err) + } +} + +func TestDialTimesOutByDefault(t *testing.T) { + pipePath := randomPipePath() + l, err := namedpipe.Listen(pipePath) + if err != nil { + t.Fatal(err) + } + defer l.Close() + pipe, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) // Should timeout after 2 seconds. + if err == nil { + pipe.Close() + } + if err != os.ErrDeadlineExceeded { + t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err) + } +} + +func TestTimeoutPendingRead(t *testing.T) { + pipePath := randomPipePath() + l, err := namedpipe.Listen(pipePath) + if err != nil { + t.Fatal(err) + } + defer l.Close() + + serverDone := make(chan struct{}) + + go func() { + s, err := l.Accept() + if err != nil { + t.Fatal(err) + } + time.Sleep(1 * time.Second) + s.Close() + close(serverDone) + }() + + client, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) + if err != nil { + t.Fatal(err) + } + defer client.Close() + + clientErr := make(chan error) + go func() { + buf := make([]byte, 10) + _, err = client.Read(buf) + clientErr <- err + }() + + time.Sleep(100 * time.Millisecond) // make *sure* the pipe is reading before we set the deadline + client.SetReadDeadline(time.Unix(1, 0)) + + select { + case err = <-clientErr: + if err != os.ErrDeadlineExceeded { + t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err) + } + case <-time.After(100 * time.Millisecond): + t.Fatalf("timed out while waiting for read to cancel") + <-clientErr + } + <-serverDone +} + +func TestTimeoutPendingWrite(t *testing.T) { + pipePath := randomPipePath() + l, err := namedpipe.Listen(pipePath) + if err != nil { + t.Fatal(err) + } + defer l.Close() + + serverDone := make(chan struct{}) + + go func() { + s, err := l.Accept() + if err != nil { + t.Fatal(err) + } + time.Sleep(1 * time.Second) + s.Close() + close(serverDone) + }() + + client, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) + if err != nil { + t.Fatal(err) + } + defer client.Close() + + clientErr := make(chan error) + go func() { + _, err = client.Write([]byte("this should timeout")) + clientErr <- err + }() + + time.Sleep(100 * time.Millisecond) // make *sure* the pipe is writing before we set the deadline + client.SetWriteDeadline(time.Unix(1, 0)) + + select { + case err = <-clientErr: + if err != os.ErrDeadlineExceeded { + t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err) + } + case <-time.After(100 * time.Millisecond): + t.Fatalf("timed out while waiting for write to cancel") + <-clientErr + } + <-serverDone +} + +type CloseWriter interface { + CloseWrite() error +} + +func TestEchoWithMessaging(t *testing.T) { + pipePath := randomPipePath() + l, err := (&namedpipe.ListenConfig{ + MessageMode: true, // Use message mode so that CloseWrite() is supported + InputBufferSize: 65536, // Use 64KB buffers to improve performance + OutputBufferSize: 65536, + }).Listen(pipePath) + if err != nil { + t.Fatal(err) + } + defer l.Close() + + listenerDone := make(chan bool) + clientDone := make(chan bool) + go func() { + // server echo + conn, err := l.Accept() + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + time.Sleep(500 * time.Millisecond) // make *sure* we don't begin to read before eof signal is sent + _, err = io.Copy(conn, conn) + if err != nil { + t.Fatal(err) + } + conn.(CloseWriter).CloseWrite() + close(listenerDone) + }() + client, err := namedpipe.DialTimeout(pipePath, time.Second) + if err != nil { + t.Fatal(err) + } + defer client.Close() + + go func() { + // client read back + bytes := make([]byte, 2) + n, e := client.Read(bytes) + if e != nil { + t.Fatal(e) + } + if n != 2 || bytes[0] != 0 || bytes[1] != 1 { + t.Fatalf("expected 2 bytes, got %v", n) + } + close(clientDone) + }() + + payload := make([]byte, 2) + payload[0] = 0 + payload[1] = 1 + + n, err := client.Write(payload) + if err != nil { + t.Fatal(err) + } + if n != 2 { + t.Fatalf("expected 2 bytes, got %v", n) + } + client.(CloseWriter).CloseWrite() + <-listenerDone + <-clientDone +} + +func TestConnectRace(t *testing.T) { + pipePath := randomPipePath() + l, err := namedpipe.Listen(pipePath) + if err != nil { + t.Fatal(err) + } + defer l.Close() + go func() { + for { + s, err := l.Accept() + if err == net.ErrClosed { + return + } + + if err != nil { + t.Fatal(err) + } + s.Close() + } + }() + + for i := 0; i < 1000; i++ { + c, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) + if err != nil { + t.Fatal(err) + } + c.Close() + } +} + +func TestMessageReadMode(t *testing.T) { + if maj, _, _ := windows.RtlGetNtVersionNumbers(); maj <= 8 { + t.Skipf("Skipping on Windows %d", maj) + } + var wg sync.WaitGroup + defer wg.Wait() + pipePath := randomPipePath() + l, err := (&namedpipe.ListenConfig{MessageMode: true}).Listen(pipePath) + if err != nil { + t.Fatal(err) + } + defer l.Close() + + msg := ([]byte)("hello world") + + wg.Add(1) + go func() { + defer wg.Done() + s, err := l.Accept() + if err != nil { + t.Fatal(err) + } + _, err = s.Write(msg) + if err != nil { + t.Fatal(err) + } + s.Close() + }() + + c, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) + if err != nil { + t.Fatal(err) + } + defer c.Close() + + mode := uint32(windows.PIPE_READMODE_MESSAGE) + err = windows.SetNamedPipeHandleState(c.(interface{ Handle() windows.Handle }).Handle(), &mode, nil, nil) + if err != nil { + t.Fatal(err) + } + + ch := make([]byte, 1) + var vmsg []byte + for { + n, err := c.Read(ch) + if err == io.EOF { + break + } + if err != nil { + t.Fatal(err) + } + if n != 1 { + t.Fatalf("expected 1, got %d", n) + } + vmsg = append(vmsg, ch[0]) + } + if !bytes.Equal(msg, vmsg) { + t.Fatalf("expected %s, got %s", msg, vmsg) + } +} + +func TestListenConnectRace(t *testing.T) { + if testing.Short() { + t.Skip("Skipping long race test") + } + pipePath := randomPipePath() + for i := 0; i < 50 && !t.Failed(); i++ { + var wg sync.WaitGroup + wg.Add(1) + go func() { + c, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) + if err == nil { + c.Close() + } + wg.Done() + }() + s, err := namedpipe.Listen(pipePath) + if err != nil { + t.Error(i, err) + } else { + s.Close() + } + wg.Wait() + } +} diff --git a/ipc/uapi_windows.go b/ipc/uapi_windows.go index a4d68da..a1bfbd1 100644 --- a/ipc/uapi_windows.go +++ b/ipc/uapi_windows.go @@ -9,8 +9,7 @@ import ( "net" "golang.org/x/sys/windows" - - "golang.zx2c4.com/wireguard/ipc/winpipe" + "golang.zx2c4.com/wireguard/ipc/namedpipe" ) // TODO: replace these with actual standard windows error numbers from the win package @@ -61,10 +60,9 @@ func init() { } func UAPIListen(name string) (net.Listener, error) { - config := winpipe.ListenConfig{ + listener, err := (&namedpipe.ListenConfig{ SecurityDescriptor: UAPISecurityDescriptor, - } - listener, err := winpipe.Listen(`\\.\pipe\ProtectedPrefix\Administrators\WireGuard\`+name, &config) + }).Listen(`\\.\pipe\ProtectedPrefix\Administrators\WireGuard\` + name) if err != nil { return nil, err } diff --git a/ipc/winpipe/file.go b/ipc/winpipe/file.go deleted file mode 100644 index 319565f..0000000 --- a/ipc/winpipe/file.go +++ /dev/null @@ -1,286 +0,0 @@ -//go:build windows - -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2005 Microsoft - * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. - */ - -package winpipe - -import ( - "io" - "os" - "runtime" - "sync" - "sync/atomic" - "time" - "unsafe" - - "golang.org/x/sys/windows" -) - -type timeoutChan chan struct{} - -var ioInitOnce sync.Once -var ioCompletionPort windows.Handle - -// ioResult contains the result of an asynchronous IO operation -type ioResult struct { - bytes uint32 - err error -} - -// ioOperation represents an outstanding asynchronous Win32 IO -type ioOperation struct { - o windows.Overlapped - ch chan ioResult -} - -func initIo() { - h, err := windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0) - if err != nil { - panic(err) - } - ioCompletionPort = h - go ioCompletionProcessor(h) -} - -// file implements Reader, Writer, and Closer on a Win32 handle without blocking in a syscall. -// It takes ownership of this handle and will close it if it is garbage collected. -type file struct { - handle windows.Handle - wg sync.WaitGroup - wgLock sync.RWMutex - closing uint32 // used as atomic boolean - socket bool - readDeadline deadlineHandler - writeDeadline deadlineHandler -} - -type deadlineHandler struct { - setLock sync.Mutex - channel timeoutChan - channelLock sync.RWMutex - timer *time.Timer - timedout uint32 // used as atomic boolean -} - -// makeFile makes a new file from an existing file handle -func makeFile(h windows.Handle) (*file, error) { - f := &file{handle: h} - ioInitOnce.Do(initIo) - _, err := windows.CreateIoCompletionPort(h, ioCompletionPort, 0, 0) - if err != nil { - return nil, err - } - err = windows.SetFileCompletionNotificationModes(h, windows.FILE_SKIP_COMPLETION_PORT_ON_SUCCESS|windows.FILE_SKIP_SET_EVENT_ON_HANDLE) - if err != nil { - return nil, err - } - f.readDeadline.channel = make(timeoutChan) - f.writeDeadline.channel = make(timeoutChan) - return f, nil -} - -// closeHandle closes the resources associated with a Win32 handle -func (f *file) closeHandle() { - f.wgLock.Lock() - // Atomically set that we are closing, releasing the resources only once. - if atomic.SwapUint32(&f.closing, 1) == 0 { - f.wgLock.Unlock() - // cancel all IO and wait for it to complete - windows.CancelIoEx(f.handle, nil) - f.wg.Wait() - // at this point, no new IO can start - windows.Close(f.handle) - f.handle = 0 - } else { - f.wgLock.Unlock() - } -} - -// Close closes a file. -func (f *file) Close() error { - f.closeHandle() - return nil -} - -// prepareIo prepares for a new IO operation. -// The caller must call f.wg.Done() when the IO is finished, prior to Close() returning. -func (f *file) prepareIo() (*ioOperation, error) { - f.wgLock.RLock() - if atomic.LoadUint32(&f.closing) == 1 { - f.wgLock.RUnlock() - return nil, os.ErrClosed - } - f.wg.Add(1) - f.wgLock.RUnlock() - c := &ioOperation{} - c.ch = make(chan ioResult) - return c, nil -} - -// ioCompletionProcessor processes completed async IOs forever -func ioCompletionProcessor(h windows.Handle) { - for { - var bytes uint32 - var key uintptr - var op *ioOperation - err := windows.GetQueuedCompletionStatus(h, &bytes, &key, (**windows.Overlapped)(unsafe.Pointer(&op)), windows.INFINITE) - if op == nil { - panic(err) - } - op.ch <- ioResult{bytes, err} - } -} - -// asyncIo processes the return value from ReadFile or WriteFile, blocking until -// the operation has actually completed. -func (f *file) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err error) (int, error) { - if err != windows.ERROR_IO_PENDING { - return int(bytes), err - } - - if atomic.LoadUint32(&f.closing) == 1 { - windows.CancelIoEx(f.handle, &c.o) - } - - var timeout timeoutChan - if d != nil { - d.channelLock.Lock() - timeout = d.channel - d.channelLock.Unlock() - } - - var r ioResult - select { - case r = <-c.ch: - err = r.err - if err == windows.ERROR_OPERATION_ABORTED { - if atomic.LoadUint32(&f.closing) == 1 { - err = os.ErrClosed - } - } else if err != nil && f.socket { - // err is from Win32. Query the overlapped structure to get the winsock error. - var bytes, flags uint32 - err = windows.WSAGetOverlappedResult(f.handle, &c.o, &bytes, false, &flags) - } - case <-timeout: - windows.CancelIoEx(f.handle, &c.o) - r = <-c.ch - err = r.err - if err == windows.ERROR_OPERATION_ABORTED { - err = os.ErrDeadlineExceeded - } - } - - // runtime.KeepAlive is needed, as c is passed via native - // code to ioCompletionProcessor, c must remain alive - // until the channel read is complete. - runtime.KeepAlive(c) - return int(r.bytes), err -} - -// Read reads from a file handle. -func (f *file) Read(b []byte) (int, error) { - c, err := f.prepareIo() - if err != nil { - return 0, err - } - defer f.wg.Done() - - if atomic.LoadUint32(&f.readDeadline.timedout) == 1 { - return 0, os.ErrDeadlineExceeded - } - - var bytes uint32 - err = windows.ReadFile(f.handle, b, &bytes, &c.o) - n, err := f.asyncIo(c, &f.readDeadline, bytes, err) - runtime.KeepAlive(b) - - // Handle EOF conditions. - if err == nil && n == 0 && len(b) != 0 { - return 0, io.EOF - } else if err == windows.ERROR_BROKEN_PIPE { - return 0, io.EOF - } else { - return n, err - } -} - -// Write writes to a file handle. -func (f *file) Write(b []byte) (int, error) { - c, err := f.prepareIo() - if err != nil { - return 0, err - } - defer f.wg.Done() - - if atomic.LoadUint32(&f.writeDeadline.timedout) == 1 { - return 0, os.ErrDeadlineExceeded - } - - var bytes uint32 - err = windows.WriteFile(f.handle, b, &bytes, &c.o) - n, err := f.asyncIo(c, &f.writeDeadline, bytes, err) - runtime.KeepAlive(b) - return n, err -} - -func (f *file) SetReadDeadline(deadline time.Time) error { - return f.readDeadline.set(deadline) -} - -func (f *file) SetWriteDeadline(deadline time.Time) error { - return f.writeDeadline.set(deadline) -} - -func (f *file) Flush() error { - return windows.FlushFileBuffers(f.handle) -} - -func (f *file) Fd() uintptr { - return uintptr(f.handle) -} - -func (d *deadlineHandler) set(deadline time.Time) error { - d.setLock.Lock() - defer d.setLock.Unlock() - - if d.timer != nil { - if !d.timer.Stop() { - <-d.channel - } - d.timer = nil - } - atomic.StoreUint32(&d.timedout, 0) - - select { - case <-d.channel: - d.channelLock.Lock() - d.channel = make(chan struct{}) - d.channelLock.Unlock() - default: - } - - if deadline.IsZero() { - return nil - } - - timeoutIO := func() { - atomic.StoreUint32(&d.timedout, 1) - close(d.channel) - } - - now := time.Now() - duration := deadline.Sub(now) - if deadline.After(now) { - // Deadline is in the future, set a timer to wait - d.timer = time.AfterFunc(duration, timeoutIO) - } else { - // Deadline is in the past. Cancel all pending IO now. - timeoutIO() - } - return nil -} diff --git a/ipc/winpipe/winpipe.go b/ipc/winpipe/winpipe.go deleted file mode 100644 index e3719d6..0000000 --- a/ipc/winpipe/winpipe.go +++ /dev/null @@ -1,474 +0,0 @@ -//go:build windows - -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2005 Microsoft - * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. - */ - -// Package winpipe implements a net.Conn and net.Listener around Windows named pipes. -package winpipe - -import ( - "context" - "io" - "net" - "os" - "runtime" - "time" - "unsafe" - - "golang.org/x/sys/windows" -) - -type pipe struct { - *file - path string -} - -type messageBytePipe struct { - pipe - writeClosed bool - readEOF bool -} - -type pipeAddress string - -func (f *pipe) LocalAddr() net.Addr { - return pipeAddress(f.path) -} - -func (f *pipe) RemoteAddr() net.Addr { - return pipeAddress(f.path) -} - -func (f *pipe) SetDeadline(t time.Time) error { - f.SetReadDeadline(t) - f.SetWriteDeadline(t) - return nil -} - -// CloseWrite closes the write side of a message pipe in byte mode. -func (f *messageBytePipe) CloseWrite() error { - if f.writeClosed { - return io.ErrClosedPipe - } - err := f.file.Flush() - if err != nil { - return err - } - _, err = f.file.Write(nil) - if err != nil { - return err - } - f.writeClosed = true - return nil -} - -// Write writes bytes to a message pipe in byte mode. Zero-byte writes are ignored, since -// they are used to implement CloseWrite. -func (f *messageBytePipe) Write(b []byte) (int, error) { - if f.writeClosed { - return 0, io.ErrClosedPipe - } - if len(b) == 0 { - return 0, nil - } - return f.file.Write(b) -} - -// Read reads bytes from a message pipe in byte mode. A read of a zero-byte message on a message -// mode pipe will return io.EOF, as will all subsequent reads. -func (f *messageBytePipe) Read(b []byte) (int, error) { - if f.readEOF { - return 0, io.EOF - } - n, err := f.file.Read(b) - if err == io.EOF { - // If this was the result of a zero-byte read, then - // it is possible that the read was due to a zero-size - // message. Since we are simulating CloseWrite with a - // zero-byte message, ensure that all future Read calls - // also return EOF. - f.readEOF = true - } else if err == windows.ERROR_MORE_DATA { - // ERROR_MORE_DATA indicates that the pipe's read mode is message mode - // and the message still has more bytes. Treat this as a success, since - // this package presents all named pipes as byte streams. - err = nil - } - return n, err -} - -func (f *pipe) Handle() windows.Handle { - return f.handle -} - -func (s pipeAddress) Network() string { - return "pipe" -} - -func (s pipeAddress) String() string { - return string(s) -} - -// tryDialPipe attempts to dial the specified pipe until cancellation or timeout. -func tryDialPipe(ctx context.Context, path *string) (windows.Handle, error) { - for { - select { - case <-ctx.Done(): - return 0, ctx.Err() - default: - path16, err := windows.UTF16PtrFromString(*path) - if err != nil { - return 0, err - } - h, err := windows.CreateFile(path16, windows.GENERIC_READ|windows.GENERIC_WRITE, 0, nil, windows.OPEN_EXISTING, windows.FILE_FLAG_OVERLAPPED|windows.SECURITY_SQOS_PRESENT|windows.SECURITY_ANONYMOUS, 0) - if err == nil { - return h, nil - } - if err != windows.ERROR_PIPE_BUSY { - return h, &os.PathError{Err: err, Op: "open", Path: *path} - } - // Wait 10 msec and try again. This is a rather simplistic - // view, as we always try each 10 milliseconds. - time.Sleep(10 * time.Millisecond) - } - } -} - -// DialConfig exposes various options for use in Dial and DialContext. -type DialConfig struct { - ExpectedOwner *windows.SID // If non-nil, the pipe is verified to be owned by this SID. -} - -// Dial connects to the specified named pipe by path, timing out if the connection -// takes longer than the specified duration. If timeout is nil, then we use -// a default timeout of 2 seconds. -func Dial(path string, timeout *time.Duration, config *DialConfig) (net.Conn, error) { - var absTimeout time.Time - if timeout != nil { - absTimeout = time.Now().Add(*timeout) - } else { - absTimeout = time.Now().Add(2 * time.Second) - } - ctx, _ := context.WithDeadline(context.Background(), absTimeout) - conn, err := DialContext(ctx, path, config) - if err == context.DeadlineExceeded { - return nil, os.ErrDeadlineExceeded - } - return conn, err -} - -// DialContext attempts to connect to the specified named pipe by path -// cancellation or timeout. -func DialContext(ctx context.Context, path string, config *DialConfig) (net.Conn, error) { - if config == nil { - config = &DialConfig{} - } - var err error - var h windows.Handle - h, err = tryDialPipe(ctx, &path) - if err != nil { - return nil, err - } - - if config.ExpectedOwner != nil { - sd, err := windows.GetSecurityInfo(h, windows.SE_FILE_OBJECT, windows.OWNER_SECURITY_INFORMATION) - if err != nil { - windows.Close(h) - return nil, err - } - realOwner, _, err := sd.Owner() - if err != nil { - windows.Close(h) - return nil, err - } - if !realOwner.Equals(config.ExpectedOwner) { - windows.Close(h) - return nil, windows.ERROR_ACCESS_DENIED - } - } - - var flags uint32 - err = windows.GetNamedPipeInfo(h, &flags, nil, nil, nil) - if err != nil { - windows.Close(h) - return nil, err - } - - f, err := makeFile(h) - if err != nil { - windows.Close(h) - return nil, err - } - - // If the pipe is in message mode, return a message byte pipe, which - // supports CloseWrite. - if flags&windows.PIPE_TYPE_MESSAGE != 0 { - return &messageBytePipe{ - pipe: pipe{file: f, path: path}, - }, nil - } - return &pipe{file: f, path: path}, nil -} - -type acceptResponse struct { - f *file - err error -} - -type pipeListener struct { - firstHandle windows.Handle - path string - config ListenConfig - acceptCh chan (chan acceptResponse) - closeCh chan int - doneCh chan int -} - -func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *ListenConfig, first bool) (windows.Handle, error) { - path16, err := windows.UTF16PtrFromString(path) - if err != nil { - return 0, &os.PathError{Op: "open", Path: path, Err: err} - } - - var oa windows.OBJECT_ATTRIBUTES - oa.Length = uint32(unsafe.Sizeof(oa)) - - var ntPath windows.NTUnicodeString - if err := windows.RtlDosPathNameToNtPathName(path16, &ntPath, nil, nil); err != nil { - if ntstatus, ok := err.(windows.NTStatus); ok { - err = ntstatus.Errno() - } - return 0, &os.PathError{Op: "open", Path: path, Err: err} - } - defer windows.LocalFree(windows.Handle(unsafe.Pointer(ntPath.Buffer))) - oa.ObjectName = &ntPath - - // The security descriptor is only needed for the first pipe. - if first { - if sd != nil { - oa.SecurityDescriptor = sd - } else { - // Construct the default named pipe security descriptor. - var acl *windows.ACL - if err := windows.RtlDefaultNpAcl(&acl); err != nil { - return 0, err - } - defer windows.LocalFree(windows.Handle(unsafe.Pointer(acl))) - sd, err := windows.NewSecurityDescriptor() - if err != nil { - return 0, err - } - if err = sd.SetDACL(acl, true, false); err != nil { - return 0, err - } - oa.SecurityDescriptor = sd - } - } - - typ := uint32(windows.FILE_PIPE_REJECT_REMOTE_CLIENTS) - if c.MessageMode { - typ |= windows.FILE_PIPE_MESSAGE_TYPE - } - - disposition := uint32(windows.FILE_OPEN) - access := uint32(windows.GENERIC_READ | windows.GENERIC_WRITE | windows.SYNCHRONIZE) - if first { - disposition = windows.FILE_CREATE - // By not asking for read or write access, the named pipe file system - // will put this pipe into an initially disconnected state, blocking - // client connections until the next call with first == false. - access = windows.SYNCHRONIZE - } - - timeout := int64(-50 * 10000) // 50ms - - var ( - h windows.Handle - iosb windows.IO_STATUS_BLOCK - ) - err = windows.NtCreateNamedPipeFile(&h, access, &oa, &iosb, windows.FILE_SHARE_READ|windows.FILE_SHARE_WRITE, disposition, 0, typ, 0, 0, 0xffffffff, uint32(c.InputBufferSize), uint32(c.OutputBufferSize), &timeout) - if err != nil { - if ntstatus, ok := err.(windows.NTStatus); ok { - err = ntstatus.Errno() - } - return 0, &os.PathError{Op: "open", Path: path, Err: err} - } - - runtime.KeepAlive(ntPath) - return h, nil -} - -func (l *pipeListener) makeServerPipe() (*file, error) { - h, err := makeServerPipeHandle(l.path, nil, &l.config, false) - if err != nil { - return nil, err - } - f, err := makeFile(h) - if err != nil { - windows.Close(h) - return nil, err - } - return f, nil -} - -func (l *pipeListener) makeConnectedServerPipe() (*file, error) { - p, err := l.makeServerPipe() - if err != nil { - return nil, err - } - - // Wait for the client to connect. - ch := make(chan error) - go func(p *file) { - ch <- connectPipe(p) - }(p) - - select { - case err = <-ch: - if err != nil { - p.Close() - p = nil - } - case <-l.closeCh: - // Abort the connect request by closing the handle. - p.Close() - p = nil - err = <-ch - if err == nil || err == os.ErrClosed { - err = net.ErrClosed - } - } - return p, err -} - -func (l *pipeListener) listenerRoutine() { - closed := false - for !closed { - select { - case <-l.closeCh: - closed = true - case responseCh := <-l.acceptCh: - var ( - p *file - err error - ) - for { - p, err = l.makeConnectedServerPipe() - // If the connection was immediately closed by the client, try - // again. - if err != windows.ERROR_NO_DATA { - break - } - } - responseCh <- acceptResponse{p, err} - closed = err == net.ErrClosed - } - } - windows.Close(l.firstHandle) - l.firstHandle = 0 - // Notify Close and Accept callers that the handle has been closed. - close(l.doneCh) -} - -// ListenConfig contains configuration for the pipe listener. -type ListenConfig struct { - // SecurityDescriptor contains a Windows security descriptor. If nil, the default from RtlDefaultNpAcl is used. - SecurityDescriptor *windows.SECURITY_DESCRIPTOR - - // MessageMode determines whether the pipe is in byte or message mode. In either - // case the pipe is read in byte mode by default. The only practical difference in - // this implementation is that CloseWrite is only supported for message mode pipes; - // CloseWrite is implemented as a zero-byte write, but zero-byte writes are only - // transferred to the reader (and returned as io.EOF in this implementation) - // when the pipe is in message mode. - MessageMode bool - - // InputBufferSize specifies the initial size of the input buffer, in bytes, which the OS will grow as needed. - InputBufferSize int32 - - // OutputBufferSize specifies the initial size of the output buffer, in bytes, which the OS will grow as needed. - OutputBufferSize int32 -} - -// Listen creates a listener on a Windows named pipe path,such as \\.\pipe\mypipe. -// The pipe must not already exist. -func Listen(path string, c *ListenConfig) (net.Listener, error) { - if c == nil { - c = &ListenConfig{} - } - h, err := makeServerPipeHandle(path, c.SecurityDescriptor, c, true) - if err != nil { - return nil, err - } - l := &pipeListener{ - firstHandle: h, - path: path, - config: *c, - acceptCh: make(chan (chan acceptResponse)), - closeCh: make(chan int), - doneCh: make(chan int), - } - // The first connection is swallowed on Windows 7 & 8, so synthesize it. - if maj, _, _ := windows.RtlGetNtVersionNumbers(); maj <= 8 { - path16, err := windows.UTF16PtrFromString(path) - if err == nil { - h, err = windows.CreateFile(path16, 0, 0, nil, windows.OPEN_EXISTING, windows.SECURITY_SQOS_PRESENT|windows.SECURITY_ANONYMOUS, 0) - if err == nil { - windows.CloseHandle(h) - } - } - } - go l.listenerRoutine() - return l, nil -} - -func connectPipe(p *file) error { - c, err := p.prepareIo() - if err != nil { - return err - } - defer p.wg.Done() - - err = windows.ConnectNamedPipe(p.handle, &c.o) - _, err = p.asyncIo(c, nil, 0, err) - if err != nil && err != windows.ERROR_PIPE_CONNECTED { - return err - } - return nil -} - -func (l *pipeListener) Accept() (net.Conn, error) { - ch := make(chan acceptResponse) - select { - case l.acceptCh <- ch: - response := <-ch - err := response.err - if err != nil { - return nil, err - } - if l.config.MessageMode { - return &messageBytePipe{ - pipe: pipe{file: response.f, path: l.path}, - }, nil - } - return &pipe{file: response.f, path: l.path}, nil - case <-l.doneCh: - return nil, net.ErrClosed - } -} - -func (l *pipeListener) Close() error { - select { - case l.closeCh <- 1: - <-l.doneCh - case <-l.doneCh: - } - return nil -} - -func (l *pipeListener) Addr() net.Addr { - return pipeAddress(l.path) -} diff --git a/ipc/winpipe/winpipe_test.go b/ipc/winpipe/winpipe_test.go deleted file mode 100644 index ea515e3..0000000 --- a/ipc/winpipe/winpipe_test.go +++ /dev/null @@ -1,660 +0,0 @@ -//go:build windows - -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2005 Microsoft - * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. - */ - -package winpipe_test - -import ( - "bufio" - "bytes" - "context" - "errors" - "io" - "net" - "os" - "sync" - "syscall" - "testing" - "time" - - "golang.org/x/sys/windows" - "golang.zx2c4.com/wireguard/ipc/winpipe" -) - -func randomPipePath() string { - guid, err := windows.GenerateGUID() - if err != nil { - panic(err) - } - return `\\.\PIPE\go-winpipe-test-` + guid.String() -} - -func TestPingPong(t *testing.T) { - const ( - ping = 42 - pong = 24 - ) - pipePath := randomPipePath() - listener, err := winpipe.Listen(pipePath, nil) - if err != nil { - t.Fatalf("unable to listen on pipe: %v", err) - } - defer listener.Close() - go func() { - incoming, err := listener.Accept() - if err != nil { - t.Fatalf("unable to accept pipe connection: %v", err) - } - defer incoming.Close() - var data [1]byte - _, err = incoming.Read(data[:]) - if err != nil { - t.Fatalf("unable to read ping from pipe: %v", err) - } - if data[0] != ping { - t.Fatalf("expected ping, got %d", data[0]) - } - data[0] = pong - _, err = incoming.Write(data[:]) - if err != nil { - t.Fatalf("unable to write pong to pipe: %v", err) - } - }() - client, err := winpipe.Dial(pipePath, nil, nil) - if err != nil { - t.Fatalf("unable to dial pipe: %v", err) - } - defer client.Close() - var data [1]byte - data[0] = ping - _, err = client.Write(data[:]) - if err != nil { - t.Fatalf("unable to write ping to pipe: %v", err) - } - _, err = client.Read(data[:]) - if err != nil { - t.Fatalf("unable to read pong from pipe: %v", err) - } - if data[0] != pong { - t.Fatalf("expected pong, got %d", data[0]) - } -} - -func TestDialUnknownFailsImmediately(t *testing.T) { - _, err := winpipe.Dial(randomPipePath(), nil, nil) - if !errors.Is(err, syscall.ENOENT) { - t.Fatalf("expected ENOENT got %v", err) - } -} - -func TestDialListenerTimesOut(t *testing.T) { - pipePath := randomPipePath() - l, err := winpipe.Listen(pipePath, nil) - if err != nil { - t.Fatal(err) - } - defer l.Close() - d := 10 * time.Millisecond - _, err = winpipe.Dial(pipePath, &d, nil) - if err != os.ErrDeadlineExceeded { - t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err) - } -} - -func TestDialContextListenerTimesOut(t *testing.T) { - pipePath := randomPipePath() - l, err := winpipe.Listen(pipePath, nil) - if err != nil { - t.Fatal(err) - } - defer l.Close() - d := 10 * time.Millisecond - ctx, _ := context.WithTimeout(context.Background(), d) - _, err = winpipe.DialContext(ctx, pipePath, nil) - if err != context.DeadlineExceeded { - t.Fatalf("expected context.DeadlineExceeded, got %v", err) - } -} - -func TestDialListenerGetsCancelled(t *testing.T) { - pipePath := randomPipePath() - ctx, cancel := context.WithCancel(context.Background()) - l, err := winpipe.Listen(pipePath, nil) - if err != nil { - t.Fatal(err) - } - ch := make(chan error) - defer l.Close() - go func(ctx context.Context, ch chan error) { - _, err := winpipe.DialContext(ctx, pipePath, nil) - ch <- err - }(ctx, ch) - time.Sleep(time.Millisecond * 30) - cancel() - err = <-ch - if err != context.Canceled { - t.Fatalf("expected context.Canceled, got %v", err) - } -} - -func TestDialAccessDeniedWithRestrictedSD(t *testing.T) { - if windows.NewLazySystemDLL("ntdll.dll").NewProc("wine_get_version").Find() == nil { - t.Skip("dacls on named pipes are broken on wine") - } - pipePath := randomPipePath() - sd, _ := windows.SecurityDescriptorFromString("D:") - c := winpipe.ListenConfig{ - SecurityDescriptor: sd, - } - l, err := winpipe.Listen(pipePath, &c) - if err != nil { - t.Fatal(err) - } - defer l.Close() - _, err = winpipe.Dial(pipePath, nil, nil) - if !errors.Is(err, windows.ERROR_ACCESS_DENIED) { - t.Fatalf("expected ERROR_ACCESS_DENIED, got %v", err) - } -} - -func getConnection(cfg *winpipe.ListenConfig) (client net.Conn, server net.Conn, err error) { - pipePath := randomPipePath() - l, err := winpipe.Listen(pipePath, cfg) - if err != nil { - return - } - defer l.Close() - - type response struct { - c net.Conn - err error - } - ch := make(chan response) - go func() { - c, err := l.Accept() - ch <- response{c, err} - }() - - c, err := winpipe.Dial(pipePath, nil, nil) - if err != nil { - return - } - - r := <-ch - if err = r.err; err != nil { - c.Close() - return - } - - client = c - server = r.c - return -} - -func TestReadTimeout(t *testing.T) { - c, s, err := getConnection(nil) - if err != nil { - t.Fatal(err) - } - defer c.Close() - defer s.Close() - - c.SetReadDeadline(time.Now().Add(10 * time.Millisecond)) - - buf := make([]byte, 10) - _, err = c.Read(buf) - if err != os.ErrDeadlineExceeded { - t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err) - } -} - -func server(l net.Listener, ch chan int) { - c, err := l.Accept() - if err != nil { - panic(err) - } - rw := bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c)) - s, err := rw.ReadString('\n') - if err != nil { - panic(err) - } - _, err = rw.WriteString("got " + s) - if err != nil { - panic(err) - } - err = rw.Flush() - if err != nil { - panic(err) - } - c.Close() - ch <- 1 -} - -func TestFullListenDialReadWrite(t *testing.T) { - pipePath := randomPipePath() - l, err := winpipe.Listen(pipePath, nil) - if err != nil { - t.Fatal(err) - } - defer l.Close() - - ch := make(chan int) - go server(l, ch) - - c, err := winpipe.Dial(pipePath, nil, nil) - if err != nil { - t.Fatal(err) - } - defer c.Close() - - rw := bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c)) - _, err = rw.WriteString("hello world\n") - if err != nil { - t.Fatal(err) - } - err = rw.Flush() - if err != nil { - t.Fatal(err) - } - - s, err := rw.ReadString('\n') - if err != nil { - t.Fatal(err) - } - ms := "got hello world\n" - if s != ms { - t.Errorf("expected '%s', got '%s'", ms, s) - } - - <-ch -} - -func TestCloseAbortsListen(t *testing.T) { - pipePath := randomPipePath() - l, err := winpipe.Listen(pipePath, nil) - if err != nil { - t.Fatal(err) - } - - ch := make(chan error) - go func() { - _, err := l.Accept() - ch <- err - }() - - time.Sleep(30 * time.Millisecond) - l.Close() - - err = <-ch - if err != net.ErrClosed { - t.Fatalf("expected net.ErrClosed, got %v", err) - } -} - -func ensureEOFOnClose(t *testing.T, r io.Reader, w io.Closer) { - b := make([]byte, 10) - w.Close() - n, err := r.Read(b) - if n > 0 { - t.Errorf("unexpected byte count %d", n) - } - if err != io.EOF { - t.Errorf("expected EOF: %v", err) - } -} - -func TestCloseClientEOFServer(t *testing.T) { - c, s, err := getConnection(nil) - if err != nil { - t.Fatal(err) - } - defer c.Close() - defer s.Close() - ensureEOFOnClose(t, c, s) -} - -func TestCloseServerEOFClient(t *testing.T) { - c, s, err := getConnection(nil) - if err != nil { - t.Fatal(err) - } - defer c.Close() - defer s.Close() - ensureEOFOnClose(t, s, c) -} - -func TestCloseWriteEOF(t *testing.T) { - cfg := &winpipe.ListenConfig{ - MessageMode: true, - } - c, s, err := getConnection(cfg) - if err != nil { - t.Fatal(err) - } - defer c.Close() - defer s.Close() - - type closeWriter interface { - CloseWrite() error - } - - err = c.(closeWriter).CloseWrite() - if err != nil { - t.Fatal(err) - } - - b := make([]byte, 10) - _, err = s.Read(b) - if err != io.EOF { - t.Fatal(err) - } -} - -func TestAcceptAfterCloseFails(t *testing.T) { - pipePath := randomPipePath() - l, err := winpipe.Listen(pipePath, nil) - if err != nil { - t.Fatal(err) - } - l.Close() - _, err = l.Accept() - if err != net.ErrClosed { - t.Fatalf("expected net.ErrClosed, got %v", err) - } -} - -func TestDialTimesOutByDefault(t *testing.T) { - pipePath := randomPipePath() - l, err := winpipe.Listen(pipePath, nil) - if err != nil { - t.Fatal(err) - } - defer l.Close() - _, err = winpipe.Dial(pipePath, nil, nil) - if err != os.ErrDeadlineExceeded { - t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err) - } -} - -func TestTimeoutPendingRead(t *testing.T) { - pipePath := randomPipePath() - l, err := winpipe.Listen(pipePath, nil) - if err != nil { - t.Fatal(err) - } - defer l.Close() - - serverDone := make(chan struct{}) - - go func() { - s, err := l.Accept() - if err != nil { - t.Fatal(err) - } - time.Sleep(1 * time.Second) - s.Close() - close(serverDone) - }() - - client, err := winpipe.Dial(pipePath, nil, nil) - if err != nil { - t.Fatal(err) - } - defer client.Close() - - clientErr := make(chan error) - go func() { - buf := make([]byte, 10) - _, err = client.Read(buf) - clientErr <- err - }() - - time.Sleep(100 * time.Millisecond) // make *sure* the pipe is reading before we set the deadline - client.SetReadDeadline(time.Unix(1, 0)) - - select { - case err = <-clientErr: - if err != os.ErrDeadlineExceeded { - t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err) - } - case <-time.After(100 * time.Millisecond): - t.Fatalf("timed out while waiting for read to cancel") - <-clientErr - } - <-serverDone -} - -func TestTimeoutPendingWrite(t *testing.T) { - pipePath := randomPipePath() - l, err := winpipe.Listen(pipePath, nil) - if err != nil { - t.Fatal(err) - } - defer l.Close() - - serverDone := make(chan struct{}) - - go func() { - s, err := l.Accept() - if err != nil { - t.Fatal(err) - } - time.Sleep(1 * time.Second) - s.Close() - close(serverDone) - }() - - client, err := winpipe.Dial(pipePath, nil, nil) - if err != nil { - t.Fatal(err) - } - defer client.Close() - - clientErr := make(chan error) - go func() { - _, err = client.Write([]byte("this should timeout")) - clientErr <- err - }() - - time.Sleep(100 * time.Millisecond) // make *sure* the pipe is writing before we set the deadline - client.SetWriteDeadline(time.Unix(1, 0)) - - select { - case err = <-clientErr: - if err != os.ErrDeadlineExceeded { - t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err) - } - case <-time.After(100 * time.Millisecond): - t.Fatalf("timed out while waiting for write to cancel") - <-clientErr - } - <-serverDone -} - -type CloseWriter interface { - CloseWrite() error -} - -func TestEchoWithMessaging(t *testing.T) { - c := winpipe.ListenConfig{ - MessageMode: true, // Use message mode so that CloseWrite() is supported - InputBufferSize: 65536, // Use 64KB buffers to improve performance - OutputBufferSize: 65536, - } - pipePath := randomPipePath() - l, err := winpipe.Listen(pipePath, &c) - if err != nil { - t.Fatal(err) - } - defer l.Close() - - listenerDone := make(chan bool) - clientDone := make(chan bool) - go func() { - // server echo - conn, e := l.Accept() - if e != nil { - t.Fatal(e) - } - defer conn.Close() - - time.Sleep(500 * time.Millisecond) // make *sure* we don't begin to read before eof signal is sent - io.Copy(conn, conn) - conn.(CloseWriter).CloseWrite() - close(listenerDone) - }() - timeout := 1 * time.Second - client, err := winpipe.Dial(pipePath, &timeout, nil) - if err != nil { - t.Fatal(err) - } - defer client.Close() - - go func() { - // client read back - bytes := make([]byte, 2) - n, e := client.Read(bytes) - if e != nil { - t.Fatal(e) - } - if n != 2 { - t.Fatalf("expected 2 bytes, got %v", n) - } - close(clientDone) - }() - - payload := make([]byte, 2) - payload[0] = 0 - payload[1] = 1 - - n, err := client.Write(payload) - if err != nil { - t.Fatal(err) - } - if n != 2 { - t.Fatalf("expected 2 bytes, got %v", n) - } - client.(CloseWriter).CloseWrite() - <-listenerDone - <-clientDone -} - -func TestConnectRace(t *testing.T) { - pipePath := randomPipePath() - l, err := winpipe.Listen(pipePath, nil) - if err != nil { - t.Fatal(err) - } - defer l.Close() - go func() { - for { - s, err := l.Accept() - if err == net.ErrClosed { - return - } - - if err != nil { - t.Fatal(err) - } - s.Close() - } - }() - - for i := 0; i < 1000; i++ { - c, err := winpipe.Dial(pipePath, nil, nil) - if err != nil { - t.Fatal(err) - } - c.Close() - } -} - -func TestMessageReadMode(t *testing.T) { - if maj, _, _ := windows.RtlGetNtVersionNumbers(); maj <= 8 { - t.Skipf("Skipping on Windows %d", maj) - } - var wg sync.WaitGroup - defer wg.Wait() - pipePath := randomPipePath() - l, err := winpipe.Listen(pipePath, &winpipe.ListenConfig{MessageMode: true}) - if err != nil { - t.Fatal(err) - } - defer l.Close() - - msg := ([]byte)("hello world") - - wg.Add(1) - go func() { - defer wg.Done() - s, err := l.Accept() - if err != nil { - t.Fatal(err) - } - _, err = s.Write(msg) - if err != nil { - t.Fatal(err) - } - s.Close() - }() - - c, err := winpipe.Dial(pipePath, nil, nil) - if err != nil { - t.Fatal(err) - } - defer c.Close() - - mode := uint32(windows.PIPE_READMODE_MESSAGE) - err = windows.SetNamedPipeHandleState(c.(interface{ Handle() windows.Handle }).Handle(), &mode, nil, nil) - if err != nil { - t.Fatal(err) - } - - ch := make([]byte, 1) - var vmsg []byte - for { - n, err := c.Read(ch) - if err == io.EOF { - break - } - if err != nil { - t.Fatal(err) - } - if n != 1 { - t.Fatalf("expected 1, got %d", n) - } - vmsg = append(vmsg, ch[0]) - } - if !bytes.Equal(msg, vmsg) { - t.Fatalf("expected %s, got %s", msg, vmsg) - } -} - -func TestListenConnectRace(t *testing.T) { - if testing.Short() { - t.Skip("Skipping long race test") - } - pipePath := randomPipePath() - for i := 0; i < 50 && !t.Failed(); i++ { - var wg sync.WaitGroup - wg.Add(1) - go func() { - c, err := winpipe.Dial(pipePath, nil, nil) - if err == nil { - c.Close() - } - wg.Done() - }() - s, err := winpipe.Listen(pipePath, nil) - if err != nil { - t.Error(i, err) - } else { - s.Close() - } - wg.Wait() - } -} diff --git a/tun/wintun/dll_windows.go b/tun/wintun/dll_windows.go deleted file mode 100644 index 3832c1e..0000000 --- a/tun/wintun/dll_windows.go +++ /dev/null @@ -1,128 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. - */ - -package wintun - -import ( - "fmt" - "sync" - "sync/atomic" - "unsafe" - - "golang.org/x/sys/windows" -) - -func newLazyDLL(name string, onLoad func(d *lazyDLL)) *lazyDLL { - return &lazyDLL{Name: name, onLoad: onLoad} -} - -func (d *lazyDLL) NewProc(name string) *lazyProc { - return &lazyProc{dll: d, Name: name} -} - -type lazyProc struct { - Name string - mu sync.Mutex - dll *lazyDLL - addr uintptr -} - -func (p *lazyProc) Find() error { - if atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&p.addr))) != nil { - return nil - } - p.mu.Lock() - defer p.mu.Unlock() - if p.addr != 0 { - return nil - } - - err := p.dll.Load() - if err != nil { - return fmt.Errorf("Error loading %v DLL: %w", p.dll.Name, err) - } - addr, err := p.nameToAddr() - if err != nil { - return fmt.Errorf("Error getting %v address: %w", p.Name, err) - } - - atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&p.addr)), unsafe.Pointer(addr)) - return nil -} - -func (p *lazyProc) Addr() uintptr { - err := p.Find() - if err != nil { - panic(err) - } - return p.addr -} - -type lazyDLL struct { - Name string - mu sync.Mutex - module windows.Handle - onLoad func(d *lazyDLL) -} - -func (d *lazyDLL) Load() error { - if atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&d.module))) != nil { - return nil - } - d.mu.Lock() - defer d.mu.Unlock() - if d.module != 0 { - return nil - } - - const ( - LOAD_LIBRARY_SEARCH_APPLICATION_DIR = 0x00000200 - LOAD_LIBRARY_SEARCH_SYSTEM32 = 0x00000800 - ) - module, err := windows.LoadLibraryEx(d.Name, 0, LOAD_LIBRARY_SEARCH_APPLICATION_DIR|LOAD_LIBRARY_SEARCH_SYSTEM32) - if err != nil { - return fmt.Errorf("Unable to load library: %w", err) - } - - atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&d.module)), unsafe.Pointer(module)) - if d.onLoad != nil { - d.onLoad(d) - } - return nil -} - -func (p *lazyProc) nameToAddr() (uintptr, error) { - return windows.GetProcAddress(p.dll.module, p.Name) -} - -// Version returns the version of the Wintun DLL. -func Version() string { - if modwintun.Load() != nil { - return "unknown" - } - resInfo, err := windows.FindResource(modwintun.module, windows.ResourceID(1), windows.RT_VERSION) - if err != nil { - return "unknown" - } - data, err := windows.LoadResourceData(modwintun.module, resInfo) - if err != nil { - return "unknown" - } - - var fixedInfo *windows.VS_FIXEDFILEINFO - fixedInfoLen := uint32(unsafe.Sizeof(*fixedInfo)) - err = windows.VerQueryValue(unsafe.Pointer(&data[0]), `\`, unsafe.Pointer(&fixedInfo), &fixedInfoLen) - if err != nil { - return "unknown" - } - version := fmt.Sprintf("%d.%d", (fixedInfo.FileVersionMS>>16)&0xff, (fixedInfo.FileVersionMS>>0)&0xff) - if nextNibble := (fixedInfo.FileVersionLS >> 16) & 0xff; nextNibble != 0 { - version += fmt.Sprintf(".%d", nextNibble) - } - if nextNibble := (fixedInfo.FileVersionLS >> 0) & 0xff; nextNibble != 0 { - version += fmt.Sprintf(".%d", nextNibble) - } - return version -} diff --git a/tun/wintun/session_windows.go b/tun/wintun/session_windows.go deleted file mode 100644 index f023baf..0000000 --- a/tun/wintun/session_windows.go +++ /dev/null @@ -1,90 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. - */ - -package wintun - -import ( - "syscall" - "unsafe" - - "golang.org/x/sys/windows" -) - -type Session struct { - handle uintptr -} - -const ( - PacketSizeMax = 0xffff // Maximum packet size - RingCapacityMin = 0x20000 // Minimum ring capacity (128 kiB) - RingCapacityMax = 0x4000000 // Maximum ring capacity (64 MiB) -) - -// Packet with data -type Packet struct { - Next *Packet // Pointer to next packet in queue - Size uint32 // Size of packet (max WINTUN_MAX_IP_PACKET_SIZE) - Data *[PacketSizeMax]byte // Pointer to layer 3 IPv4 or IPv6 packet -} - -var ( - procWintunAllocateSendPacket = modwintun.NewProc("WintunAllocateSendPacket") - procWintunEndSession = modwintun.NewProc("WintunEndSession") - procWintunGetReadWaitEvent = modwintun.NewProc("WintunGetReadWaitEvent") - procWintunReceivePacket = modwintun.NewProc("WintunReceivePacket") - procWintunReleaseReceivePacket = modwintun.NewProc("WintunReleaseReceivePacket") - procWintunSendPacket = modwintun.NewProc("WintunSendPacket") - procWintunStartSession = modwintun.NewProc("WintunStartSession") -) - -func (wintun *Adapter) StartSession(capacity uint32) (session Session, err error) { - r0, _, e1 := syscall.Syscall(procWintunStartSession.Addr(), 2, uintptr(wintun.handle), uintptr(capacity), 0) - if r0 == 0 { - err = e1 - } else { - session = Session{r0} - } - return -} - -func (session Session) End() { - syscall.Syscall(procWintunEndSession.Addr(), 1, session.handle, 0, 0) - session.handle = 0 -} - -func (session Session) ReadWaitEvent() (handle windows.Handle) { - r0, _, _ := syscall.Syscall(procWintunGetReadWaitEvent.Addr(), 1, session.handle, 0, 0) - handle = windows.Handle(r0) - return -} - -func (session Session) ReceivePacket() (packet []byte, err error) { - var packetSize uint32 - r0, _, e1 := syscall.Syscall(procWintunReceivePacket.Addr(), 2, session.handle, uintptr(unsafe.Pointer(&packetSize)), 0) - if r0 == 0 { - err = e1 - return - } - packet = unsafe.Slice((*byte)(unsafe.Pointer(r0)), packetSize) - return -} - -func (session Session) ReleaseReceivePacket(packet []byte) { - syscall.Syscall(procWintunReleaseReceivePacket.Addr(), 2, session.handle, uintptr(unsafe.Pointer(&packet[0])), 0) -} - -func (session Session) AllocateSendPacket(packetSize int) (packet []byte, err error) { - r0, _, e1 := syscall.Syscall(procWintunAllocateSendPacket.Addr(), 2, session.handle, uintptr(packetSize), 0) - if r0 == 0 { - err = e1 - return - } - packet = unsafe.Slice((*byte)(unsafe.Pointer(r0)), packetSize) - return -} - -func (session Session) SendPacket(packet []byte) { - syscall.Syscall(procWintunSendPacket.Addr(), 2, session.handle, uintptr(unsafe.Pointer(&packet[0])), 0) -} diff --git a/tun/wintun/wintun_windows.go b/tun/wintun/wintun_windows.go deleted file mode 100644 index 2fe26a7..0000000 --- a/tun/wintun/wintun_windows.go +++ /dev/null @@ -1,150 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. - */ - -package wintun - -import ( - "log" - "runtime" - "syscall" - "unsafe" - - "golang.org/x/sys/windows" -) - -type loggerLevel int - -const ( - logInfo loggerLevel = iota - logWarn - logErr -) - -const AdapterNameMax = 128 - -type Adapter struct { - handle uintptr -} - -var ( - modwintun = newLazyDLL("wintun.dll", setupLogger) - procWintunCreateAdapter = modwintun.NewProc("WintunCreateAdapter") - procWintunOpenAdapter = modwintun.NewProc("WintunOpenAdapter") - procWintunCloseAdapter = modwintun.NewProc("WintunCloseAdapter") - procWintunDeleteDriver = modwintun.NewProc("WintunDeleteDriver") - procWintunGetAdapterLUID = modwintun.NewProc("WintunGetAdapterLUID") - procWintunGetRunningDriverVersion = modwintun.NewProc("WintunGetRunningDriverVersion") -) - -type TimestampedWriter interface { - WriteWithTimestamp(p []byte, ts int64) (n int, err error) -} - -func logMessage(level loggerLevel, timestamp uint64, msg *uint16) int { - if tw, ok := log.Default().Writer().(TimestampedWriter); ok { - tw.WriteWithTimestamp([]byte(log.Default().Prefix()+windows.UTF16PtrToString(msg)), (int64(timestamp)-116444736000000000)*100) - } else { - log.Println(windows.UTF16PtrToString(msg)) - } - return 0 -} - -func setupLogger(dll *lazyDLL) { - var callback uintptr - if runtime.GOARCH == "386" { - callback = windows.NewCallback(func(level loggerLevel, timestampLow, timestampHigh uint32, msg *uint16) int { - return logMessage(level, uint64(timestampHigh)<<32|uint64(timestampLow), msg) - }) - } else if runtime.GOARCH == "arm" { - callback = windows.NewCallback(func(level loggerLevel, _, timestampLow, timestampHigh uint32, msg *uint16) int { - return logMessage(level, uint64(timestampHigh)<<32|uint64(timestampLow), msg) - }) - } else if runtime.GOARCH == "amd64" || runtime.GOARCH == "arm64" { - callback = windows.NewCallback(logMessage) - } - syscall.Syscall(dll.NewProc("WintunSetLogger").Addr(), 1, callback, 0, 0) -} - -func closeAdapter(wintun *Adapter) { - syscall.Syscall(procWintunCloseAdapter.Addr(), 1, wintun.handle, 0, 0) -} - -// CreateAdapter creates a Wintun adapter. name is the cosmetic name of the adapter. -// tunnelType represents the type of adapter and should be "Wintun". requestedGUID is -// the GUID of the created network adapter, which then influences NLA generation -// deterministically. If it is set to nil, the GUID is chosen by the system at random, -// and hence a new NLA entry is created for each new adapter. -func CreateAdapter(name string, tunnelType string, requestedGUID *windows.GUID) (wintun *Adapter, err error) { - var name16 *uint16 - name16, err = windows.UTF16PtrFromString(name) - if err != nil { - return - } - var tunnelType16 *uint16 - tunnelType16, err = windows.UTF16PtrFromString(tunnelType) - if err != nil { - return - } - r0, _, e1 := syscall.Syscall(procWintunCreateAdapter.Addr(), 3, uintptr(unsafe.Pointer(name16)), uintptr(unsafe.Pointer(tunnelType16)), uintptr(unsafe.Pointer(requestedGUID))) - if r0 == 0 { - err = e1 - return - } - wintun = &Adapter{handle: r0} - runtime.SetFinalizer(wintun, closeAdapter) - return -} - -// OpenAdapter opens an existing Wintun adapter by name. -func OpenAdapter(name string) (wintun *Adapter, err error) { - var name16 *uint16 - name16, err = windows.UTF16PtrFromString(name) - if err != nil { - return - } - r0, _, e1 := syscall.Syscall(procWintunOpenAdapter.Addr(), 1, uintptr(unsafe.Pointer(name16)), 0, 0) - if r0 == 0 { - err = e1 - return - } - wintun = &Adapter{handle: r0} - runtime.SetFinalizer(wintun, closeAdapter) - return -} - -// Close closes a Wintun adapter. -func (wintun *Adapter) Close() (err error) { - runtime.SetFinalizer(wintun, nil) - r1, _, e1 := syscall.Syscall(procWintunCloseAdapter.Addr(), 1, wintun.handle, 0, 0) - if r1 == 0 { - err = e1 - } - return -} - -// Uninstall removes the driver from the system if no drivers are currently in use. -func Uninstall() (err error) { - r1, _, e1 := syscall.Syscall(procWintunDeleteDriver.Addr(), 0, 0, 0, 0) - if r1 == 0 { - err = e1 - } - return -} - -// RunningVersion returns the version of the loaded driver. -func RunningVersion() (version uint32, err error) { - r0, _, e1 := syscall.Syscall(procWintunGetRunningDriverVersion.Addr(), 0, 0, 0, 0) - version = uint32(r0) - if version == 0 { - err = e1 - } - return -} - -// LUID returns the LUID of the adapter. -func (wintun *Adapter) LUID() (luid uint64) { - syscall.Syscall(procWintunGetAdapterLUID.Addr(), 2, uintptr(wintun.handle), uintptr(unsafe.Pointer(&luid)), 0) - return -} -- cgit v1.2.3-54-g00ecf