/*
* Copyright 2006-2010, Haiku, Inc. All Rights Reserved.
* Distributed under the terms of the MIT License.
*
* Authors:
* Axel Dörfler, axeld@pinc-software.de
*/
#include "stack_private.h"
#include <stdlib.h>
#include <string.h>
#include <sys/ioctl.h>
#include <sys/time.h>
#include <new>
#include <Drivers.h>
#include <KernelExport.h>
#include <Select.h>
#include <AutoDeleter.h>
#include <team.h>
#include <util/AutoLock.h>
#include <util/list.h>
#include <WeakReferenceable.h>
#include <fs/select_sync_pool.h>
#include <kernel.h>
#include <net_protocol.h>
#include <net_stack.h>
#include <net_stat.h>
#include "ancillary_data.h"
#include "utility.h"
// Uncomment the following define to enable the TRACE() output of this file.
//#define TRACE_SOCKET
#ifdef TRACE_SOCKET
// tracing enabled: forward to dprintf(), prepending STACK_DEBUG_PREFIX
# define TRACE(x...) dprintf(STACK_DEBUG_PREFIX x)
#else
// tracing disabled: TRACE() expands to an empty statement
# define TRACE(x...) ;
#endif
struct net_socket_private;
typedef DoublyLinkedList<net_socket_private> SocketList;
/*! Stack-internal representation of a socket: the public net_socket extended
	by list linkage, weak-reference support, and the state needed to keep
	track of the sockets that were spawned from this one.
*/
struct net_socket_private : net_socket,
DoublyLinkedListLinkImpl<net_socket_private>,
BWeakReferenceable {
net_socket_private();
~net_socket_private();
void RemoveFromParent();
// weak reference to the socket this one was spawned from (set by
// socket_spawn_pending(), cleared by RemoveFromParent())
BWeakReference<net_socket_private> parent;
// team owning the socket; -1 until it has been assigned
team_id owner;
// limit configured via socket_set_max_backlog()
uint32 max_backlog;
// number of sockets currently queued in pending_children and
// connected_children together
uint32 child_count;
// spawned sockets that have not been reported as connected yet
SocketList pending_children;
// spawned sockets that socket_connected() has been called for
SocketList connected_children;
// select()/poll() listeners registered via socket_request_notification()
struct select_sync_pool* select_pool;
// held while the child lists, child_count, or select_pool are accessed
mutex lock;
// set by socket_connected() once the socket was moved to its parent's
// connected_children list
bool is_connected;
// whether the socket has been added to the global sSocketList
bool is_in_socket_list;
};
// Forward declarations of functions that are used before they are defined.
int socket_bind(net_socket* socket, const struct sockaddr* address,
socklen_t addressLength);
int socket_setsockopt(net_socket* socket, int level, int option,
const void* value, int length);
ssize_t socket_read_avail(net_socket* socket);
// Global list of sockets; sockets are added by socket_open() and by
// net_socket_private::RemoveFromParent(), and removed in the destructor.
static SocketList sSocketList;
// locked around the accesses to sSocketList
static mutex sSocketLock;
/*!	Creates a socket that has no protocols attached, is neither bound nor
	connected, and uses the default send/receive parameters.
*/
net_socket_private::net_socket_private()
	:
	owner(-1),
	max_backlog(0),
	child_count(0),
	select_pool(NULL),
	is_connected(false),
	is_in_socket_list(false)
{
	mutex_init(&lock, "socket");

	// no protocol chain yet
	first_info = NULL;
	first_protocol = NULL;

	// neither a local nor a peer address
	peer.ss_len = 0;
	address.ss_len = 0;

	error = 0;
	bound_to_device = 0;
	linger = 0;
	options = 0;

	// The following are only defaults, the protocols may override them.
	receive.timeout = B_INFINITE_TIMEOUT;
	receive.low_water_mark = 1;
	receive.buffer_size = 65535;

	send.timeout = B_INFINITE_TIMEOUT;
	send.low_water_mark = 1;
	send.buffer_size = 65535;
}
/*!	Removes the socket from the global socket list (if it is part of it),
	detaches all children that are still queued, and puts the domain
	protocols. Panics if the socket still has a parent.
*/
net_socket_private::~net_socket_private()
{
	TRACE("delete net_socket %p\n", this);

	if (parent != NULL)
		panic("socket still has a parent!");

	if (is_in_socket_list) {
		mutex_lock(&sSocketLock);
		sSocketList.Remove(this);
		mutex_unlock(&sSocketLock);
	}

	{
		// detach every child that is still queued with this socket
		MutexLocker childLocker(lock);

		net_socket_private* child;
		while ((child = pending_children.RemoveHead()) != NULL)
			child->RemoveFromParent();
		while ((child = connected_children.RemoveHead()) != NULL)
			child->RemoveFromParent();
	}

	put_domain_protocols(this);

	mutex_destroy(&lock);
}
void
net_socket_private::RemoveFromParent()
{
ASSERT(!is_in_socket_list && parent != NULL);
parent = NULL;
mutex_lock(&sSocketLock);
sSocketList.Add(this);
mutex_unlock(&sSocketLock);
is_in_socket_list = true;
ReleaseReference();
}
// #pragma mark -
/*!	Sums up the \c iov_len fields of \a count iovecs. Each iovec is fetched
	with user_memcpy().
	\return The accumulated length, or 0 if any of the iovecs could not be
		copied.
*/
static size_t
compute_user_iovec_length(iovec* userVec, uint32 count)
{
	size_t total = 0;
	iovec current;

	for (iovec* end = userVec + count; userVec != end; userVec++) {
		if (user_memcpy(&current, userVec, sizeof(iovec)) < B_OK)
			return 0;

		total += current.iov_len;
	}

	return total;
}
/*!	Allocates a new socket for the given family/type/protocol triple, and
	retrieves its domain protocols.
	\param _socket Set to the new socket on success.
	\return B_OK on success, B_NO_MEMORY if the allocation failed, or the
		error returned by InitCheck() or get_domain_protocols().
*/
static status_t
create_socket(int family, int type, int protocol, net_socket_private** _socket)
{
	net_socket_private* created = new(std::nothrow) net_socket_private;
	if (created == NULL)
		return B_NO_MEMORY;

	status_t result = created->InitCheck();
	if (result == B_OK) {
		created->family = family;
		created->type = type;
		created->protocol = protocol;

		result = get_domain_protocols(created);
	}

	if (result != B_OK) {
		delete created;
		return result;
	}

	TRACE("create net_socket %p (%u.%u.%u):\n", created, created->family,
		created->type, created->protocol);

#ifdef TRACE_SOCKET
	// list the protocol chain of the new socket
	int index = 0;
	for (net_protocol* current = created->first_protocol; current != NULL;
			current = current->next) {
		TRACE(" [%d] %p %s\n", index, current, current->module->info.name);
		index++;
	}
#endif

	*_socket = created;
	return B_OK;
}
/*!	Walks the control message headers in \a data, and hands each of them to
	the add_ancillary_data() hook of the socket's first protocol.

	\param socket The socket the data is sent on.
	\param container The container the protocol adds the data to.
	\param data Start of the control buffer.
	\param dataLen Size of the control buffer in bytes.
	\return B_OK on success (including an empty buffer), B_NOT_SUPPORTED if
		the protocol has no add_ancillary_data() hook, B_BAD_VALUE if a
		header does not fit into the remaining buffer, or the error returned
		by the hook.
*/
static status_t
add_ancillary_data(net_socket* socket, ancillary_data_container* container,
	void* data, size_t dataLen)
{
	if (dataLen == 0)
		return B_OK;

	if (socket->first_info->add_ancillary_data == NULL)
		return B_NOT_SUPPORTED;

	cmsghdr* header = (cmsghdr*)data;

	while (true) {
		// The remaining space must be able to hold a header before its
		// cmsg_len field may be read. Such input was rejected before as
		// well, but only after reading beyond the end of the buffer.
		if (dataLen < CMSG_LEN(0))
			return B_BAD_VALUE;

		if (header->cmsg_len < CMSG_LEN(0) || header->cmsg_len > dataLen)
			return B_BAD_VALUE;

		status_t status = socket->first_info->add_ancillary_data(
			socket->first_protocol, container, header);
		if (status != B_OK)
			return status;

		// advance to the next (aligned) header, if there is any space left
		if (dataLen <= _ALIGN(header->cmsg_len))
			break;

		dataLen -= _ALIGN(header->cmsg_len);
		header = (cmsghdr*)((uint8*)header + _ALIGN(header->cmsg_len));
	}

	return B_OK;
}
/*!	Lets the socket's first protocol write all entries of \a container into
	the control buffer of \a messageHeader.
	On success, \c msg_controllen is reduced by the space that was left over
	in the control buffer; it is set to 0 if there is no container or no
	control buffer.
	\return B_OK on success, B_NOT_SUPPORTED if there is data to process but
		the protocol has no process_ancillary_data() hook, or the error
		returned by the hook.
*/
static status_t
process_ancillary_data(net_socket* socket, ancillary_data_container* container,
	msghdr* messageHeader)
{
	uint8* target = (uint8*)messageHeader->msg_control;
	int spaceLeft = messageHeader->msg_controllen;

	if (target == NULL || container == NULL) {
		messageHeader->msg_controllen = 0;
		return B_OK;
	}

	ancillary_data_header dataHeader;
	void* data = NULL;

	for (;;) {
		data = next_ancillary_data(container, data, &dataHeader);
		if (data == NULL)
			break;

		if (socket->first_info->process_ancillary_data == NULL)
			return B_NOT_SUPPORTED;

		ssize_t written = socket->first_info->process_ancillary_data(
			socket->first_protocol, &dataHeader, data, target, spaceLeft);
		if (written < 0)
			return written;

		target += written;
		spaceLeft -= written;
	}

	messageHeader->msg_controllen -= spaceLeft;
	return B_OK;
}
/*!	Variant for buffers that do not carry an ancillary data container: the
	socket's first protocol fills the control buffer of \a messageHeader
	directly from \a buffer.
	On success, \c msg_controllen is set to the number of bytes the protocol
	reported; it is set to 0 if there is no control buffer.
	\return B_OK on success, B_NOT_SUPPORTED if the protocol has no
		process_ancillary_data_no_container() hook, or the error returned by
		the hook.
*/
static status_t
process_ancillary_data(net_socket* socket,
	net_buffer* buffer, msghdr* messageHeader)
{
	void* control = messageHeader->msg_control;
	if (control == NULL) {
		messageHeader->msg_controllen = 0;
		return B_OK;
	}

	if (socket->first_info->process_ancillary_data_no_container == NULL)
		return B_NOT_SUPPORTED;

	ssize_t written = socket->first_info->process_ancillary_data_no_container(
		socket->first_protocol, buffer, control,
		messageHeader->msg_controllen);
	if (written < 0)
		return written;

	messageHeader->msg_controllen = written;
	return B_OK;
}
/*!	Receives data via the protocol's read_data_no_buffer() hook.
	If \a header is given, its iovecs and name fields are passed to the
	protocol, and the ancillary data is processed into its control buffer;
	otherwise the single buffer (\a data, \a length) is used.
	\a flags is not used by this function.
	\return The number of bytes read, or an error code.
*/
static ssize_t
socket_receive_no_buffer(net_socket* socket, msghdr* header, void* data,
	size_t length, int flags)
{
	iovec singleVec = { data, length };

	iovec* vecs = &singleVec;
	int vecCount = 1;
	sockaddr* address = NULL;
	socklen_t* addressLen = NULL;

	if (header != NULL) {
		vecs = header->msg_iov;
		vecCount = header->msg_iovlen;
		address = (sockaddr*)header->msg_name;
		addressLen = &header->msg_namelen;
	}

	ancillary_data_container* ancillaryData = NULL;
	ssize_t bytesRead = socket->first_info->read_data_no_buffer(
		socket->first_protocol, vecs, vecCount, &ancillaryData, address,
		addressLen);
	if (bytesRead < 0)
		return bytesRead;

	// from here on, the container is deleted when leaving the function
	CObjectDeleter<ancillary_data_container> ancillaryDataDeleter(ancillaryData,
		&delete_ancillary_data_container);

	if (header == NULL)
		return bytesRead;

	// process ancillary data
	status_t status = process_ancillary_data(socket, ancillaryData, header);
	if (status != B_OK)
		return status;

	header->msg_flags = 0;
	return bytesRead;
}
#if ENABLE_DEBUGGER_COMMANDS
/*!	Prints a single line describing \a socket to the kernel debugger, each
	line starting with \a prefix. Sockets that have a parent are marked
	" (c)" if they are connected, and " (p)" otherwise.
*/
static void
print_socket_line(net_socket_private* socket, const char* prefix)
{
	BReference<net_socket_private> parent;
	if (socket->parent.PrivatePointer() != NULL)
		parent = socket->parent.GetReference();

	const char* state = "";
	if (parent.Get() != NULL)
		state = socket->is_connected ? " (c)" : " (p)";

	kprintf("%s%p %2d.%2d.%2d %6" B_PRId32 " %p %p %p%s\n", prefix, socket,
		socket->family, socket->type, socket->protocol, socket->owner,
		socket->first_protocol, socket->first_info, parent.Get(), state);
}
/*!	Kernel debugger command: dumps the socket at the address given as first
	argument, followed by its pending and connected children, if any.
*/
static int
dump_socket(int argc, char** argv)
{
	if (argc < 2) {
		kprintf("usage: %s [address]\n", argv[0]);
		return 0;
	}

	net_socket_private* socket = (net_socket_private*)parse_expression(argv[1]);

	BReference<net_socket_private> parent;
	if (socket->parent.PrivatePointer() != NULL)
		parent = socket->parent.GetReference();

	kprintf("SOCKET %p\n", socket);
	kprintf(" family.type.protocol: %d.%d.%d\n",
		socket->family, socket->type, socket->protocol);
	kprintf(" parent: %p\n", parent.Get());
	kprintf(" first protocol: %p\n", socket->first_protocol);
	kprintf(" first module_info: %p\n", socket->first_info);
	kprintf(" options: %x\n", socket->options);
	kprintf(" linger: %d\n", socket->linger);
	kprintf(" bound to device: %" B_PRIu32 "\n", socket->bound_to_device);
	kprintf(" owner: %" B_PRId32 "\n", socket->owner);
	kprintf(" max backlog: %" B_PRId32 "\n", socket->max_backlog);
	kprintf(" is connected: %d\n", socket->is_connected);
	kprintf(" child_count: %" B_PRIu32 "\n", socket->child_count);

	// the child lists are only printed if there is anything to print
	if (socket->child_count != 0) {
		net_socket_private* child;

		kprintf(" pending children:\n");
		SocketList::Iterator iterator = socket->pending_children.GetIterator();
		while ((child = iterator.Next()) != NULL)
			print_socket_line(child, " ");

		kprintf(" connected children:\n");
		iterator = socket->connected_children.GetIterator();
		while ((child = iterator.Next()) != NULL)
			print_socket_line(child, " ");
	}

	return 0;
}
static int
dump_sockets(int argc, char** argv)
{
kprintf("address kind owner protocol module_info parent\n");
SocketList::Iterator iterator = sSocketList.GetIterator();
while (net_socket_private* socket = iterator.Next()) {
print_socket_line(socket, "");
SocketList::Iterator childIterator
= socket->pending_children.GetIterator();
while (net_socket_private* child = childIterator.Next()) {
print_socket_line(child, " ");
}
childIterator = socket->connected_children.GetIterator();
while (net_socket_private* child = childIterator.Next()) {
print_socket_line(child, " ");
}
}
return 0;
}
#endif // ENABLE_DEBUGGER_COMMANDS
// #pragma mark -
/*!	Creates a new socket, opens its first protocol, assigns it to the
	current team, and adds it to the global socket list.
	\param _socket Set to the new socket on success.
	\return B_OK on success, or the error of create_socket() or the
		protocol's open() hook.
*/
status_t
socket_open(int family, int type, int protocol, net_socket** _socket)
{
	net_socket_private* created;
	status_t status = create_socket(family, type, protocol, &created);
	if (status != B_OK)
		return status;

	status = created->first_info->open(created->first_protocol);
	if (status != B_OK) {
		delete created;
		return status;
	}

	created->owner = team_get_current_team_id();
	created->is_in_socket_list = true;

	{
		MutexLocker listLocker(sSocketLock);
		sSocketList.Add(created);
	}

	*_socket = created;
	return B_OK;
}
/*!	Forwards the close request to the socket's first protocol, and returns
	its result.
*/
status_t
socket_close(net_socket* _socket)
{
	net_protocol* protocol = _socket->first_protocol;
	return _socket->first_info->close(protocol);
}
/*!	Calls the free() hook of the socket's first protocol, and then releases
	one reference to the socket.
*/
void
socket_free(net_socket* _socket)
{
	net_socket_private* socket = static_cast<net_socket_private*>(_socket);

	net_protocol* protocol = socket->first_protocol;
	socket->first_info->free(protocol);

	socket->ReleaseReference();
}
/*!	Not implemented: always fails with B_ERROR (-1) without touching any of
	its arguments.
*/
status_t
socket_readv(net_socket* socket, const iovec* vecs, size_t vecCount,
	size_t* _length)
{
	return B_ERROR;
}
/*!	Sends the data described by \a vecs to the peer of a connected socket.
	The socket is bound to an empty address first, if it is not bound yet.

	\param socket The socket to send on; it must have a peer address.
	\param vecs The iovecs to send; their contents are copied into a new
		net_buffer.
	\param vecCount The number of iovecs in \a vecs.
	\param _length Set to the number of bytes sent on success. If sending
		failed after the buffer size had changed, it is set to the
		difference between the original and the remaining buffer size.
	\return B_OK on success, ECONNRESET if the socket has no peer, ENOBUFS if
		the buffer could not be created or filled, or the error of
		socket_bind() or the protocol's send_data() hook.
*/
status_t
socket_writev(net_socket* socket, const iovec* vecs, size_t vecCount,
	size_t* _length)
{
	if (socket->peer.ss_len == 0)
		return ECONNRESET;

	if (socket->address.ss_len == 0) {
		// try to bind first
		status_t status = socket_bind(socket, NULL, 0);
		if (status != B_OK)
			return status;
	}

	// TODO: useful, maybe even computed header space!
	net_buffer* buffer = gNetBufferModule.create(256);
	if (buffer == NULL)
		return ENOBUFS;

	// copy data into buffer
	for (uint32 i = 0; i < vecCount; i++) {
		if (gNetBufferModule.append(buffer, vecs[i].iov_base,
				vecs[i].iov_len) < B_OK) {
			gNetBufferModule.free(buffer);
			return ENOBUFS;
		}
	}

	memcpy(buffer->source, &socket->address, socket->address.ss_len);
	memcpy(buffer->destination, &socket->peer, socket->peer.ss_len);

	// remember the size, the buffer must not be accessed after a successful
	// send_data()
	size_t size = buffer->size;

	status_t status = socket->first_info->send_data(socket->first_protocol,
		buffer);
	if (status < B_OK) {
		if (buffer->size != size) {
			// this appears to be a partial write
			*_length = size - buffer->size;
		}
		gNetBufferModule.free(buffer);
		return status;
	}

	// send_data() only reports a status (socket_send() relies on that as
	// well), so the number of bytes sent is the size of the buffer that was
	// handed over, not the return value.
	*_length = size;
	return B_OK;
}
/*!	Handles the ioctl()s the socket layer implements itself (FIONBIO,
	FIONREAD, B_SET_BLOCKING_IO, B_SET_NONBLOCKING_IO), and forwards all
	others to the control() hook of the first protocol at level
	LEVEL_DRIVER_IOCTL.
	When called from a syscall, \a data is treated as a userland address and
	accessed via user_memcpy().
*/
status_t
socket_control(net_socket* socket, uint32 op, void* data, size_t length)
{
	if (op == FIONBIO) {
		if (data == NULL)
			return B_BAD_VALUE;

		int nonBlocking;
		if (!is_syscall())
			nonBlocking = *(int*)data;
		else if (!IS_USER_ADDRESS(data)
			|| user_memcpy(&nonBlocking, data, sizeof(int)) != B_OK) {
			return B_BAD_ADDRESS;
		}

		return socket_setsockopt(socket, SOL_SOCKET, SO_NONBLOCK, &nonBlocking,
			sizeof(int));
	}

	if (op == FIONREAD) {
		if (data == NULL)
			return B_BAD_VALUE;

		int available = (int)socket_read_avail(socket);
		if (available < 0)
			return available;

		if (!is_syscall()) {
			*(int*)data = available;
			return B_OK;
		}

		if (!IS_USER_ADDRESS(data)
			|| user_memcpy(data, &available, sizeof(available)) != B_OK) {
			return B_BAD_ADDRESS;
		}
		return B_OK;
	}

	if (op == B_SET_BLOCKING_IO || op == B_SET_NONBLOCKING_IO) {
		int nonBlocking = op == B_SET_NONBLOCKING_IO;
		return socket_setsockopt(socket, SOL_SOCKET, SO_NONBLOCK, &nonBlocking,
			sizeof(int));
	}

	// everything else is up to the protocols
	return socket->first_info->control(socket->first_protocol,
		LEVEL_DRIVER_IOCTL, op, data, &length);
}
/*!	Returns the result of the read_avail() hook of the socket's first
	protocol.
*/
ssize_t
socket_read_avail(net_socket* socket)
{
	net_protocol* protocol = socket->first_protocol;
	return socket->first_info->read_avail(protocol);
}
/*!	Returns the result of the send_avail() hook of the socket's first
	protocol.
*/
ssize_t
socket_send_avail(net_socket* socket)
{
	net_protocol* protocol = socket->first_protocol;
	return socket->first_info->send_avail(protocol);
}
/*!	Passes \a buffer on to the send_data() hook of the socket's first
	protocol, and returns its result.
*/
status_t
socket_send_data(net_socket* socket, net_buffer* buffer)
{
	net_protocol* protocol = socket->first_protocol;
	return socket->first_info->send_data(protocol, buffer);
}
/*!	Reads a buffer via the read_data() hook of the socket's first protocol.
	If the buffer returned is larger than \a length, it is trimmed to
	\a length bytes.
	\return B_OK on success, or the error returned by the hook.
*/
status_t
socket_receive_data(net_socket* socket, size_t length, uint32 flags,
	net_buffer** _buffer)
{
	status_t status = socket->first_info->read_data(socket->first_protocol,
		length, flags, _buffer);
	if (status != B_OK)
		return status;

	net_buffer* received = *_buffer;
	if (received != NULL && received->size > length) {
		// discard any data behind the amount requested
		gNetBufferModule.trim(received, length);
	}

	return B_OK;
}
/*!	Retrieves the statistics of the next socket in the global socket list.

	\param _cookie Iteration state: the number of matching sockets that have
		been reported already. Start with 0; it is advanced on success.
	\param family The address family of the sockets to report, or -1 to
		report the sockets of all families.
	\param stat Filled in with the data of the socket found; the control()
		hook of the socket's first protocol may add protocol specific data
		via NET_STAT_SOCKET.
	\return B_OK on success, B_ENTRY_NOT_FOUND if there is no further
		matching socket.
*/
status_t
socket_get_next_stat(uint32* _cookie, int family, struct net_stat* stat)
{
	MutexLocker locker(sSocketLock);

	const uint32 cookie = *_cookie;
	uint32 count = 0;

	net_socket_private* socket = NULL;
	SocketList::Iterator iterator = sSocketList.GetIterator();

	while ((socket = iterator.Next()) != NULL) {
		// TODO: also traverse the pending connections

		// Sockets of other families must neither be counted nor returned.
		// (Previously, the filter was only applied when counting, so that a
		// socket of a different family could be reported.)
		if (family != -1 && family != socket->family)
			continue;

		if (count == cookie)
			break;

		count++;
	}

	if (socket == NULL)
		return B_ENTRY_NOT_FOUND;

	*_cookie = count + 1;

	stat->family = socket->family;
	stat->type = socket->type;
	stat->protocol = socket->protocol;
	stat->owner = socket->owner;
	stat->state[0] = '\0';
	memcpy(&stat->address, &socket->address, sizeof(struct sockaddr_storage));
	memcpy(&stat->peer, &socket->peer, sizeof(struct sockaddr_storage));
	stat->receive_queue_size = 0;
	stat->send_queue_size = 0;

	// fill in protocol specific data (if supported by the protocol); the
	// result is ignored on purpose, the generic data is valid in any case
	size_t length = sizeof(net_stat);
	socket->first_info->control(socket->first_protocol, socket->protocol,
		NET_STAT_SOCKET, stat, &length);

	return B_OK;
}
// #pragma mark - connections
/*!	Acquires a reference to the socket, unless its reference count is zero.
	\return \c true if a reference was acquired, \c false otherwise.
*/
bool
socket_acquire(net_socket* _socket)
{
	net_socket_private* socket = static_cast<net_socket_private*>(_socket);

	// While the socket is being destroyed, its endpoint protocol might still
	// be able to reach it. The endpoint must not be able to acquire the
	// socket at that point anymore -- while not obvious, it is the endpoint
	// protocol that is responsible for the proper locking here.
	if (socket->CountReferences() != 0) {
		socket->AcquireReference();
		return true;
	}

	return false;
}
/*!	Releases a reference to the socket, and returns the result of
	ReleaseReference().
*/
bool
socket_release(net_socket* _socket)
{
	return static_cast<net_socket_private*>(_socket)->ReleaseReference();
}
/*!	Creates a new socket that inherits the properties of \a _parent, and
	queues it in the parent's list of pending connections.
	\param _socket Set to the new socket on success.
	\return B_OK on success, ENOBUFS if the parent already has too many
		children, or the error of create_socket().
*/
status_t
socket_spawn_pending(net_socket* _parent, net_socket** _socket)
{
	net_socket_private* parent = static_cast<net_socket_private*>(_parent);

	TRACE("%s(%p)\n", __FUNCTION__, parent);

	MutexLocker parentLocker(parent->lock);

	// More pending connections than the backlog are accepted to make up for
	// those that never complete; this also guarantees that at least a single
	// connection can always be accepted.
	const uint32 limit = 3 * parent->max_backlog / 2;
	if (parent->child_count > limit)
		return ENOBUFS;

	net_socket_private* child;
	status_t status = create_socket(parent->family, parent->type,
		parent->protocol, &child);
	if (status != B_OK)
		return status;

	// inherit parent's properties
	child->owner = parent->owner;
	child->linger = parent->linger;
	child->options = parent->options & ~SO_ACCEPTCONN;
	child->receive = parent->receive;
	child->send = parent->send;
	memcpy(&child->address, &parent->address, parent->address.ss_len);
	memcpy(&child->peer, &parent->peer, parent->peer.ss_len);

	// add to the parent's list of pending connections
	parent->pending_children.Add(child);
	child->parent = parent;
	parent->child_count++;

	*_socket = child;
	return B_OK;
}
/*!	Removes the first socket from the parent's list of connected children,
	and detaches it from the parent.
	A reference to the child is acquired before it is detached, so the
	socket stored in \a _socket comes with a reference.
	\return B_OK on success, B_ENTRY_NOT_FOUND if there is no connected
		child; \a _socket is not touched in that case.
*/
status_t
socket_dequeue_connected(net_socket* _parent, net_socket** _socket)
{
	net_socket_private* parent = static_cast<net_socket_private*>(_parent);

	MutexLocker parentLocker(parent->lock);

	net_socket_private* child = parent->connected_children.RemoveHead();
	if (child == NULL)
		return B_ENTRY_NOT_FOUND;

	child->AcquireReference();
	child->RemoveFromParent();
	parent->child_count--;

	*_socket = child;
	return B_OK;
}
/*!	Returns the number of sockets in the parent's list of connected
	children; the parent's lock is held while counting.
*/
ssize_t
socket_count_connected(net_socket* _parent)
{
	net_socket_private* parent = static_cast<net_socket_private*>(_parent);

	MutexLocker parentLocker(parent->lock);

	ssize_t connected = parent->connected_children.Count();
	return connected;
}
/*!	Sets the maximum backlog of the socket; values larger than 256 are
	clamped to 256.
	If the socket has more children than the new backlog allows, the excess
	ones are detached, starting at the tail of the pending list, and
	continuing with the tail of the connected list.
	\return Always B_OK.
*/
status_t
socket_set_max_backlog(net_socket* _socket, uint32 backlog)
{
	net_socket_private* socket = static_cast<net_socket_private*>(_socket);

	// there is an upper limit of connections waiting to be accepted
	backlog = min_c(backlog, (uint32)256);

	MutexLocker socketLocker(socket->lock);

	// drop the pending connections first...
	while (socket->child_count > backlog) {
		net_socket_private* child = socket->pending_children.RemoveTail();
		if (child == NULL)
			break;

		child->RemoveFromParent();
		socket->child_count--;
	}

	// ...and only then the already connected ones, as far as needed
	while (socket->child_count > backlog) {
		net_socket_private* child = socket->connected_children.RemoveTail();
		if (child == NULL)
			break;

		child->RemoveFromParent();
		socket->child_count--;
	}

	socket->max_backlog = backlog;
	return B_OK;
}
/*!	Tells whether the socket's parent reference is set. This does not imply
	that the parent is still valid.
*/
bool
socket_has_parent(net_socket* _socket)
{
	net_socket_private* socket = static_cast<net_socket_private*>(_socket);

	bool hasParent = socket->parent != NULL;
	return hasParent;
}
/*!	Marks the socket as connected: it is moved from the pending to the
	connected list of its parent, and the parent's select pool (if any) is
	notified with B_SELECT_READ.
	\return B_OK on success, B_BAD_VALUE if no reference to the parent could
		be obtained.
*/
status_t
socket_connected(net_socket* _socket)
{
	net_socket_private* socket = static_cast<net_socket_private*>(_socket);

	TRACE("socket_connected(%p)\n", socket);

	BReference<net_socket_private> parentReference
		= socket->parent.GetReference();
	net_socket_private* parent = parentReference.Get();
	if (parent == NULL)
		return B_BAD_VALUE;

	MutexLocker parentLocker(parent->lock);

	parent->pending_children.Remove(socket);
	parent->connected_children.Add(socket);
	socket->is_connected = true;

	// notify parent
	if (parent->select_pool != NULL)
		notify_select_event_pool(parent->select_pool, B_SELECT_READ);

	return B_OK;
}
/*!	The connection of the socket has been aborted: it is removed from the
	child list of its parent it is queued in, and detached from the parent
	via RemoveFromParent(), which releases a reference to the socket.
	\return B_OK on success, B_BAD_VALUE if no reference to the parent could
		be obtained.
*/
status_t
socket_aborted(net_socket* _socket)
{
	net_socket_private* socket = static_cast<net_socket_private*>(_socket);

	TRACE("socket_aborted(%p)\n", socket);

	BReference<net_socket_private> parentReference
		= socket->parent.GetReference();
	net_socket_private* parent = parentReference.Get();
	if (parent == NULL)
		return B_BAD_VALUE;

	MutexLocker parentLocker(parent->lock);

	// the list the socket is part of depends on how far it got
	SocketList& children = socket->is_connected
		? parent->connected_children : parent->pending_children;
	children.Remove(socket);

	parent->child_count--;
	socket->RemoveFromParent();

	return B_OK;
}
// #pragma mark - notifications
/*!	Registers \a sync for \a event in the select pool of the socket, and
	notifies it right away if the event is already present.
	For read/write events, that is the case if the protocol reports at least
	as many bytes available as the respective low water mark, or if it
	reports an error; for error events, if the socket has an error set.
	\return B_OK on success, or the error of add_select_sync_pool_entry().
*/
status_t
socket_request_notification(net_socket* _socket, uint8 event, selectsync* sync)
{
	net_socket_private* socket = static_cast<net_socket_private*>(_socket);

	status_t status;
	{
		MutexLocker socketLocker(socket->lock);
		status = add_select_sync_pool_entry(&socket->select_pool, sync, event);
	}
	if (status != B_OK)
		return status;

	// check if the event is already present
	// TODO: add support for poll() types
	bool present = false;

	if (event == B_SELECT_READ) {
		ssize_t available = socket_read_avail(socket);
		present = available < B_OK
			|| (ssize_t)socket->receive.low_water_mark <= available;
	} else if (event == B_SELECT_WRITE) {
		ssize_t available = socket_send_avail(socket);
		present = available < B_OK
			|| (ssize_t)socket->send.low_water_mark <= available;
	} else if (event == B_SELECT_ERROR)
		present = socket->error != B_OK;

	if (present)
		notify_select_event(sync, event);

	return B_OK;
}
/*!	Removes \a sync for \a event from the select pool of the socket while
	holding the socket's lock, and returns the result of
	remove_select_sync_pool_entry().
*/
status_t
socket_cancel_notification(net_socket* _socket, uint8 event, selectsync* sync)
{
	net_socket_private* socket = static_cast<net_socket_private*>(_socket);

	MutexLocker socketLocker(socket->lock);

	status_t status = remove_select_sync_pool_entry(&socket->select_pool, sync,
		event);
	return status;
}
/*!	Notifies the select pool of the socket about \a event.
	Read/write events are suppressed as long as \a value is non-negative and
	below the respective low water mark. For error events, \a value is stored
	as the socket's error, and read and write are notified as well.
	\return Always B_OK.
*/
status_t
socket_notify(net_socket* _socket, uint8 event, int32 value)
{
	net_socket_private* socket = static_cast<net_socket_private*>(_socket);

	bool suppress = false;

	if (event == B_SELECT_READ) {
		suppress = value >= B_OK
			&& (ssize_t)socket->receive.low_water_mark > value;
	} else if (event == B_SELECT_WRITE) {
		suppress = value >= B_OK
			&& (ssize_t)socket->send.low_water_mark > value;
	} else if (event == B_SELECT_ERROR)
		socket->error = value;

	MutexLocker socketLocker(socket->lock);

	if (suppress || socket->select_pool == NULL)
		return B_OK;

	notify_select_event_pool(socket->select_pool, event);

	if (event == B_SELECT_ERROR) {
		// always notify read/write on error
		notify_select_event_pool(socket->select_pool, B_SELECT_READ);
		notify_select_event_pool(socket->select_pool, B_SELECT_WRITE);
	}

	return B_OK;
}
// #pragma mark - standard socket API
/*!	Accepts a connection on a listening socket via the accept() hook of its
	first protocol.
	If \a address is given and \a _addressLength is greater than zero, the
	peer address of the accepted socket is copied to \a address (truncated
	to the space available), and \a _addressLength is set to the length of
	the peer address.
	\return B_OK on success, B_BAD_VALUE if the socket does not have
		SO_ACCEPTCONN set, or the error of the accept() hook.
*/
int
socket_accept(net_socket* socket, struct sockaddr* address,
	socklen_t* _addressLength, net_socket** _acceptedSocket)
{
	const bool listening = (socket->options & SO_ACCEPTCONN) != 0;
	if (!listening)
		return B_BAD_VALUE;

	net_socket* accepted;
	status_t status = socket->first_info->accept(socket->first_protocol,
		&accepted);
	if (status != B_OK)
		return status;

	if (address != NULL && *_addressLength > 0) {
		// never copy more than the storage holds, or the caller has room for
		size_t peerLength = min_c(accepted->peer.ss_len,
			sizeof(sockaddr_storage));
		memcpy(address, &accepted->peer, min_c(*_addressLength, peerLength));
		*_addressLength = accepted->peer.ss_len;
	}

	*_acceptedSocket = accepted;
	return B_OK;
}
/*!	Binds the socket to \a address.

	\param socket The socket to bind. If it is already bound, it is unbound
		from its current address first.
	\param address The address to bind to. If \c NULL, an empty address of
		the socket's family is used instead (like INADDR_ANY).
	\param addressLength The length of \a address; replaced by
		sizeof(sockaddr) if \a address is \c NULL.
	\return B_OK on success, or the error returned by the unbind() or bind()
		hook of the first protocol. If binding fails, the socket's address
		is marked empty again.
*/
int
socket_bind(net_socket* socket, const struct sockaddr* address,
	socklen_t addressLength)
{
	sockaddr empty;
	if (address == NULL) {
		// special - try to bind to an empty address, like INADDR_ANY
		memset(&empty, 0, sizeof(sockaddr));
		empty.sa_len = sizeof(sockaddr);
		empty.sa_family = socket->family;

		address = &empty;
		addressLength = sizeof(sockaddr);
	}

	if (socket->address.ss_len != 0) {
		status_t status = socket->first_info->unbind(socket->first_protocol,
			(sockaddr*)&socket->address);
		if (status != B_OK)
			return status;
	}

	// Only sizeof(sockaddr) bytes are stored here, independent of
	// addressLength; the protocol gets passed the caller's address below.
	memcpy(&socket->address, address, sizeof(sockaddr));
	socket->address.ss_len = sizeof(sockaddr_storage);

	status_t status = socket->first_info->bind(socket->first_protocol,
		(sockaddr*)address);
	if (status != B_OK) {
		// clear address again, as binding failed
		socket->address.ss_len = 0;
	}

	return status;
}
/*!	Connects the socket to \a address via the connect() hook of its first
	protocol. An unbound socket is bound to an empty address first.
	\return ENETUNREACH if no address was given, the error of socket_bind()
		if binding failed, or the result of the connect() hook.
*/
int
socket_connect(net_socket* socket, const struct sockaddr* address,
	socklen_t addressLength)
{
	if (addressLength == 0 || address == NULL)
		return ENETUNREACH;

	const bool bound = socket->address.ss_len != 0;
	if (!bound) {
		// try to bind first
		status_t status = socket_bind(socket, NULL, 0);
		if (status != B_OK)
			return status;
	}

	net_protocol* protocol = socket->first_protocol;
	return socket->first_info->connect(protocol, address);
}
/*!	Copies the peer address of the socket to \a address, truncated to
	\a _addressLength bytes, and sets \a _addressLength to the length of the
	peer address.
	\return B_OK on success, ENOTCONN if the socket has no peer address.
*/
int
socket_getpeername(net_socket* socket, struct sockaddr* address,
	socklen_t* _addressLength)
{
	const sockaddr_storage& peer = socket->peer;
	if (peer.ss_len == 0)
		return ENOTCONN;

	memcpy(address, &peer, min_c(*_addressLength, peer.ss_len));
	*_addressLength = peer.ss_len;

	return B_OK;
}
/*!	Copies the local address of the socket to \a address, truncated to
	\a _addressLength bytes, and sets \a _addressLength to the length of the
	address.
	For an unbound socket, a zeroed sockaddr that only has the socket's
	family set is returned instead.
	\return Always B_OK.
*/
int
socket_getsockname(net_socket* socket, struct sockaddr* address,
	socklen_t* _addressLength)
{
	if (socket->address.ss_len != 0) {
		memcpy(address, &socket->address, min_c(*_addressLength,
			socket->address.ss_len));
		*_addressLength = socket->address.ss_len;
		return B_OK;
	}

	// not bound yet: report an empty address of the socket's family
	struct sockaddr unbound;
	memset(&unbound, 0, sizeof(unbound));
	unbound.sa_family = socket->family;

	memcpy(address, &unbound, min_c(*_addressLength, sizeof(unbound)));
	*_addressLength = sizeof(unbound);

	return B_OK;
}
/*!	Retrieves a SOL_SOCKET level option of the socket.
	The value is written to \a value, and its size to \a _length. Only the
	timeout options check that the size passed in is large enough.
	Retrieving SO_ERROR clears the socket's error.
	\return B_OK on success, B_BAD_VALUE if the buffer is too small for a
		timeout option, ENOPROTOOPT for other levels and unknown options.
*/
status_t
socket_get_option(net_socket* socket, int level, int option, void* value,
	int* _length)
{
	if (level != SOL_SOCKET)
		return ENOPROTOOPT;

	switch (option) {
		// buffer sizes and low water marks are all returned as uint32
		case SO_SNDBUF:
		case SO_RCVBUF:
		case SO_SNDLOWAT:
		case SO_RCVLOWAT:
		{
			uint32 result;
			if (option == SO_SNDBUF)
				result = socket->send.buffer_size;
			else if (option == SO_RCVBUF)
				result = socket->receive.buffer_size;
			else if (option == SO_SNDLOWAT)
				result = socket->send.low_water_mark;
			else
				result = socket->receive.low_water_mark;

			*(uint32*)value = result;
			*_length = sizeof(uint32);
			return B_OK;
		}

		case SO_RCVTIMEO:
		case SO_SNDTIMEO:
		{
			if (*_length < (int)sizeof(struct timeval))
				return B_BAD_VALUE;

			bigtime_t timeout = option == SO_SNDTIMEO
				? socket->send.timeout : socket->receive.timeout;
			// an infinite timeout is reported as zero
			if (timeout == B_INFINITE_TIMEOUT)
				timeout = 0;

			struct timeval* result = (struct timeval*)value;
			result->tv_sec = timeout / 1000000LL;
			result->tv_usec = timeout % 1000000LL;

			*_length = sizeof(struct timeval);
			return B_OK;
		}

		case SO_NONBLOCK:
		{
			// non-blocking means that both timeouts are zero
			const bool nonBlocking = socket->receive.timeout == 0
				&& socket->send.timeout == 0;

			*(int32*)value = nonBlocking;
			*_length = sizeof(int32);
			return B_OK;
		}

		// plain flags in the socket's options
		case SO_ACCEPTCONN:
		case SO_BROADCAST:
		case SO_DEBUG:
		case SO_DONTROUTE:
		case SO_KEEPALIVE:
		case SO_OOBINLINE:
		case SO_REUSEADDR:
		case SO_REUSEPORT:
		case SO_USELOOPBACK:
		{
			const bool isSet = (socket->options & option) != 0;

			*(int32*)value = isSet;
			*_length = sizeof(int32);
			return B_OK;
		}

		case SO_TYPE:
			*(int32*)value = socket->type;
			*_length = sizeof(int32);
			return B_OK;

		case SO_ERROR:
			*(int32*)value = socket->error;
			*_length = sizeof(int32);

			// clear error upon retrieval
			socket->error = B_OK;
			return B_OK;

		default:
			break;
	}

	dprintf("socket_getsockopt: unknown option %d\n", option);
	return ENOPROTOOPT;
}
/*!	Forwards the request to the getsockopt() hook of the module of the
	socket's first protocol, and returns its result.
*/
int
socket_getsockopt(net_socket* socket, int level, int option, void* value,
	int* _length)
{
	net_protocol* protocol = socket->first_protocol;
	return protocol->module->getsockopt(protocol, level, option, value,
		_length);
}
/*!	Calls the listen() hook of the socket's first protocol, and sets
	SO_ACCEPTCONN in the socket's options if that succeeded.
	\return The result of the listen() hook.
*/
int
socket_listen(net_socket* socket, int backlog)
{
	status_t status = socket->first_info->listen(socket->first_protocol,
		backlog);
	if (status != B_OK)
		return status;

	socket->options |= SO_ACCEPTCONN;
	return B_OK;
}
/*! Receives data from the socket.
	If the first protocol implements read_data_no_buffer(), the request is
	handed to socket_receive_no_buffer(). Otherwise a net_buffer is read via
	read_data(), and its contents are copied to (\a data, \a length), and --
	if \a header is given -- to the iovecs of the header starting at index 1.
	The header's control buffer, name, and flags are filled in as well.
	\return The number of bytes copied (or the number of bytes received, if
		the data was truncated and MSG_TRUNC is set in \a flags), 0 if the
		protocol returned no buffer, or an error code.
*/
ssize_t
socket_receive(net_socket* socket, msghdr* header, void* data, size_t length,
int flags)
{
// If the protocol sports read_data_no_buffer() we use it.
if (socket->first_info->read_data_no_buffer != NULL)
return socket_receive_no_buffer(socket, header, data, length, flags);
size_t totalLength = length;
net_buffer* buffer;
int i;
// The convention of this function is that if a header is present,
// { data, length } is identical to its iovec[0], and is always
// treated as such.
if (header) {
// calculate the length considering all of the extra buffers
for (i = 1; i < header->msg_iovlen; i++)
totalLength += header->msg_iov[i].iov_len;
}
status_t status = socket->first_info->read_data(
socket->first_protocol, totalLength, flags, &buffer);
if (status != B_OK)
return status;
// process ancillary data
if (header != NULL) {
if (buffer != NULL && header->msg_control != NULL) {
// the ancillary data either comes in a container attached to the
// buffer, or has to be retrieved from the buffer by the protocol
ancillary_data_container* container
= gNetBufferModule.get_ancillary_data(buffer);
if (container != NULL)
status = process_ancillary_data(socket, container, header);
else
status = process_ancillary_data(socket, buffer, header);
if (status != B_OK) {
gNetBufferModule.free(buffer);
return status;
}
} else
header->msg_controllen = 0;
}
// TODO: - returning a NULL buffer when received 0 bytes
// may not make much sense as we still need the address
// - gNetBufferModule.read() uses memcpy() instead of user_memcpy
size_t nameLen = 0;
if (header) {
// TODO: - consider the control buffer options
// remember the size of the name buffer; msg_namelen is only set again
// below once the source address has been copied
nameLen = header->msg_namelen;
header->msg_namelen = 0;
header->msg_flags = 0;
}
if (buffer == NULL)
return 0;
size_t bytesReceived = buffer->size, bytesCopied = 0;
// the first chunk goes to { data, length }
length = min_c(bytesReceived, length);
if (gNetBufferModule.read(buffer, 0, data, length) < B_OK) {
gNetBufferModule.free(buffer);
return ENOBUFS;
}
// if first copy was a success, proceed to following
// copies as required
bytesCopied += length;
if (header) {
// we only start considering at iovec[1]
// as { data, length } is iovec[0]
for (i = 1; i < header->msg_iovlen && bytesCopied < bytesReceived; i++) {
iovec& vec = header->msg_iov[i];
size_t toRead = min_c(bytesReceived - bytesCopied, vec.iov_len);
if (gNetBufferModule.read(buffer, bytesCopied, vec.iov_base,
toRead) < B_OK) {
// stop copying; what has been copied so far is still returned
break;
}
bytesCopied += toRead;
}
if (header->msg_name != NULL) {
// copy the source address of the buffer, truncated to the space
// that is available
header->msg_namelen = min_c(nameLen, buffer->source->sa_len);
memcpy(header->msg_name, buffer->source, header->msg_namelen);
}
}
gNetBufferModule.free(buffer);
if (bytesCopied < bytesReceived) {
// not everything fit into the buffers provided
if (header)
header->msg_flags = MSG_TRUNC;
if (flags & MSG_TRUNC)
return bytesReceived;
}
return bytesCopied;
}
/*! Sends data on the socket.
	The destination is taken from the name in \a header, or from the
	socket's peer address if it has one. An unbound socket is bound to an
	empty address first. The control data of the header (if any) is turned
	into an ancillary data container.
	If the first protocol implements send_data_no_buffer(), the iovecs are
	passed to it directly. Otherwise the data is copied into net_buffers of
	at most \c send.buffer_size bytes each, which are sent one after the
	other via send_data().
	By convention, (\a data, \a length) equals the first iovec of \a header,
	if a header is given.
	\return The number of bytes sent, or an error code.
*/
ssize_t
socket_send(net_socket* socket, msghdr* header, const void* data, size_t length,
int flags)
{
const sockaddr* address = NULL;
socklen_t addressLength = 0;
size_t bytesLeft = length;
if (length > SSIZE_MAX)
return B_BAD_VALUE;
ancillary_data_container* ancillaryData = NULL;
// deletes the container on all return paths, unless it has been detached
CObjectDeleter<ancillary_data_container> ancillaryDataDeleter(NULL,
&delete_ancillary_data_container);
if (header != NULL) {
address = (const sockaddr*)header->msg_name;
addressLength = header->msg_namelen;
// get the ancillary data
if (header->msg_control != NULL) {
ancillaryData = create_ancillary_data_container();
if (ancillaryData == NULL)
return B_NO_MEMORY;
ancillaryDataDeleter.SetTo(ancillaryData);
status_t status = add_ancillary_data(socket, ancillaryData,
(cmsghdr*)header->msg_control, header->msg_controllen);
if (status != B_OK)
return status;
}
}
// an address without a length is ignored, a length without an address is
// an error
if (addressLength == 0)
address = NULL;
else if (address == NULL)
return B_BAD_VALUE;
if (socket->peer.ss_len != 0) {
if (address != NULL)
return EISCONN;
// socket is connected, we use that address
address = (struct sockaddr*)&socket->peer;
addressLength = socket->peer.ss_len;
}
if (address == NULL || addressLength == 0) {
// don't know where to send to:
return EDESTADDRREQ;
}
// protocols with atomic messages cannot send more than fits into a single
// buffer
if ((socket->first_info->flags & NET_PROTOCOL_ATOMIC_MESSAGES) != 0
&& bytesLeft > socket->send.buffer_size)
return EMSGSIZE;
if (socket->address.ss_len == 0) {
// try to bind first
status_t status = socket_bind(socket, NULL, 0);
if (status != B_OK)
return status;
}
// If the protocol has a send_data_no_buffer() hook, we use that one.
if (socket->first_info->send_data_no_buffer != NULL) {
iovec stackVec = { (void*)data, length };
iovec* vecs = header ? header->msg_iov : &stackVec;
int vecCount = header ? header->msg_iovlen : 1;
ssize_t written = socket->first_info->send_data_no_buffer(
socket->first_protocol, vecs, vecCount, ancillaryData, address,
addressLength);
// the container is only kept alive if anything was written
// NOTE(review): presumably the protocol takes over the container in
// that case -- confirm against the hook's contract
if (written > 0)
ancillaryDataDeleter.Detach();
return written;
}
// By convention, if a header is given, the (data, length) equals the first
// iovec. So drop the header, if it is the only iovec. Otherwise compute
// the size of the remaining ones.
if (header != NULL) {
if (header->msg_iovlen <= 1)
header = NULL;
else {
// TODO: The iovecs have already been copied to kernel space. Simplify!
bytesLeft += compute_user_iovec_length(header->msg_iov + 1,
header->msg_iovlen - 1);
}
}
ssize_t bytesSent = 0;
// vecOffset: non-zero while the current iovec has only been sent partially
// vecIndex: index of the iovec that is processed next (0 is data/length)
size_t vecOffset = 0;
uint32 vecIndex = 0;
// each iteration fills and sends a single buffer
while (bytesLeft > 0) {
// TODO: useful, maybe even computed header space!
net_buffer* buffer = gNetBufferModule.create(256);
if (buffer == NULL)
return ENOBUFS;
// append iovecs until the buffer is full, or all data has been added
while (buffer->size < socket->send.buffer_size
&& buffer->size < bytesLeft) {
if (vecIndex > 0 && vecOffset == 0) {
// retrieve next iovec buffer from header
iovec vec;
if (user_memcpy(&vec, header->msg_iov + vecIndex, sizeof(iovec))
< B_OK) {
gNetBufferModule.free(buffer);
return B_BAD_ADDRESS;
}
data = vec.iov_base;
length = vec.iov_len;
}
// limit the chunk to the space left in the buffer
size_t bytes = length;
if (buffer->size + bytes > socket->send.buffer_size)
bytes = socket->send.buffer_size - buffer->size;
if (gNetBufferModule.append(buffer, data, bytes) < B_OK) {
gNetBufferModule.free(buffer);
return ENOBUFS;
}
if (bytes != length) {
// partial send
// the remainder of this iovec goes into the next buffer
vecOffset = bytes;
length -= vecOffset;
data = (uint8*)data + vecOffset;
} else if (header != NULL) {
// proceed with next buffer, if any
vecOffset = 0;
vecIndex++;
if (vecIndex >= (uint32)header->msg_iovlen)
break;
}
}
// attach ancillary data to the first buffer
status_t status = B_OK;
if (ancillaryData != NULL) {
gNetBufferModule.set_ancillary_data(buffer, ancillaryData);
ancillaryDataDeleter.Detach();
ancillaryData = NULL;
}
// remember the size to be able to detect partial writes below
size_t bufferSize = buffer->size;
buffer->flags = flags;
memcpy(buffer->source, &socket->address, socket->address.ss_len);
memcpy(buffer->destination, address, addressLength);
buffer->destination->sa_len = addressLength;
if (status == B_OK) {
status = socket->first_info->send_data(socket->first_protocol,
buffer);
}
if (status != B_OK) {
// the buffer is only freed here if sending it failed
size_t sizeAfterSend = buffer->size;
gNetBufferModule.free(buffer);
if ((sizeAfterSend != bufferSize || bytesSent > 0)
&& (status == B_INTERRUPTED || status == B_WOULD_BLOCK)) {
// this appears to be a partial write
return bytesSent + (bufferSize - sizeAfterSend);
}
return status;
}
bytesLeft -= bufferSize;
bytesSent += bufferSize;
}
return bytesSent;
}
/*!	Applies a \c SOL_SOCKET level option to \a socket.

	\param socket The socket whose settings are changed.
	\param level Must be \c SOL_SOCKET; any other level yields \c ENOPROTOOPT.
	\param option The option to set.
	\param value Points to the new value of the option.
	\param length Size of the data \a value points to.
	\return \c B_OK on success, \c B_BAD_VALUE if \a length does not fit the
		option, \c ENOPROTOOPT if the level or the option is not handled here.
*/
status_t
socket_set_option(net_socket* socket, int level, int option, const void* value,
	int length)
{
	if (level != SOL_SOCKET)
		return ENOPROTOOPT;

	TRACE("%s(socket %p, option %d\n", __FUNCTION__, socket, option);

	switch (option) {
		// TODO: implement other options!

		case SO_LINGER:
		{
			// Unlike the other options, a larger value is accepted here.
			if (length < (int)sizeof(struct linger))
				return B_BAD_VALUE;

			const struct linger* request = (const struct linger*)value;
			if (!request->l_onoff) {
				// lingering switched off: also reset the stored interval
				socket->options &= ~SO_LINGER;
				socket->linger = 0;
				return B_OK;
			}

			socket->options |= SO_LINGER;
			socket->linger = request->l_linger;
			return B_OK;
		}

		// buffer sizes and low water marks are all plain uint32 values
		case SO_SNDBUF:
		case SO_RCVBUF:
		case SO_SNDLOWAT:
		case SO_RCVLOWAT:
		{
			if (length != sizeof(uint32))
				return B_BAD_VALUE;

			const uint32 amount = *(const uint32*)value;
			if (option == SO_SNDBUF)
				socket->send.buffer_size = amount;
			else if (option == SO_RCVBUF)
				socket->receive.buffer_size = amount;
			else if (option == SO_SNDLOWAT)
				socket->send.low_water_mark = amount;
			else
				socket->receive.low_water_mark = amount;
			return B_OK;
		}

		case SO_SNDTIMEO:
		case SO_RCVTIMEO:
		{
			if (length != sizeof(struct timeval))
				return B_BAD_VALUE;

			const struct timeval* interval = (const struct timeval*)value;
			bigtime_t microseconds
				= interval->tv_sec * 1000000LL + interval->tv_usec;
			if (microseconds == 0) {
				// an all-zero timeval means "wait forever"
				microseconds = B_INFINITE_TIMEOUT;
			}

			bigtime_t& target = option == SO_SNDTIMEO
				? socket->send.timeout : socket->receive.timeout;
			target = microseconds;
			return B_OK;
		}

		case SO_NONBLOCK:
		{
			if (length != sizeof(int32))
				return B_BAD_VALUE;

			// Non-blocking mode is expressed as a zero timeout in both
			// directions, blocking mode as an infinite one.
			const bool nonBlocking = *(const int32*)value != 0;
			const bigtime_t timeout = nonBlocking ? 0 : B_INFINITE_TIMEOUT;
			socket->send.timeout = timeout;
			socket->receive.timeout = timeout;
			return B_OK;
		}

		// boolean options: the option constant is used as the flag bit
		case SO_BROADCAST:
		case SO_DEBUG:
		case SO_DONTROUTE:
		case SO_KEEPALIVE:
		case SO_OOBINLINE:
		case SO_REUSEADDR:
		case SO_REUSEPORT:
		case SO_USELOOPBACK:
		{
			if (length != sizeof(int32))
				return B_BAD_VALUE;

			if (*(const int32*)value == 0)
				socket->options &= ~option;
			else
				socket->options |= option;
			return B_OK;
		}

		case SO_BINDTODEVICE:
		{
			if (length != sizeof(uint32))
				return B_BAD_VALUE;

			// TODO: we might want to check if the device exists at all
			// (although it doesn't really harm when we don't)
			const uint32 deviceIndex = *(const uint32*)value;
			socket->bound_to_device = deviceIndex;
			return B_OK;
		}

		default:
			break;
	}

	dprintf("socket_setsockopt: unknown option %d\n", option);
	return ENOPROTOOPT;
}
/*!	Hands the request to the setsockopt() hook of the module of the first
	protocol of \a socket, and returns whatever that hook returns. All
	arguments are passed on unchanged.
*/
int
socket_setsockopt(net_socket* socket, int level, int option, const void* value,
	int length)
{
	int result = socket->first_protocol->module->setsockopt(
		socket->first_protocol, level, option, value, length);
	return result;
}
/*!	Calls the shutdown() hook found in \c first_info of \a socket for its
	first protocol with the given \a direction, and returns the hook's result.
*/
int
socket_shutdown(net_socket* socket, int direction)
{
	int result = socket->first_info->shutdown(socket->first_protocol,
		direction);
	return result;
}
/*!	Creates two sockets that are connected to each other.

	The first socket is bound (without an explicit address), put into
	listening mode with a backlog of 1, and the second socket is connected to
	the address the first one got. The connection is then accepted on the
	first socket; the listener is closed and freed, and the accepted socket
	takes its place in \a sockets[0].

	\param family, type, protocol Passed on to socket_open() for both sockets.
	\param sockets Receives the two connected sockets on success. On failure
		both entries are \c NULL.
	\return \c B_OK on success, or the error of the first step that failed.
*/
status_t
socket_socketpair(int family, int type, int protocol, net_socket* sockets[2])
{
	sockets[0] = NULL;
	sockets[1] = NULL;

	// create sockets
	status_t error = socket_open(family, type, protocol, &sockets[0]);
	if (error != B_OK)
		return error;

	// The first socket exists at this point, so no further check is needed
	// before opening the second one.
	error = socket_open(family, type, protocol, &sockets[1]);

	// bind one
	if (error == B_OK)
		error = socket_bind(sockets[0], NULL, 0);

	// start listening
	if (error == B_OK)
		error = socket_listen(sockets[0], 1);

	// connect them
	if (error == B_OK) {
		error = socket_connect(sockets[1], (sockaddr*)&sockets[0]->address,
			sockets[0]->address.ss_len);
	}

	// accept a socket
	net_socket* acceptedSocket = NULL;
	if (error == B_OK)
		error = socket_accept(sockets[0], NULL, NULL, &acceptedSocket);

	if (error == B_OK) {
		// everything worked: close the listener socket
		socket_close(sockets[0]);
		socket_free(sockets[0]);
		sockets[0] = acceptedSocket;
	} else {
		// close sockets on error
		for (int i = 0; i < 2; i++) {
			if (sockets[i] != NULL) {
				socket_close(sockets[i]);
				socket_free(sockets[i]);
				sockets[i] = NULL;
			}
		}
	}

	return error;
}
// #pragma mark -
/*!	Standard module operations of the socket module.

	On \c B_MODULE_INIT the global socket list and its lock are set up, on
	\c B_MODULE_UNINIT the lock is destroyed again (the list is expected to be
	empty by then). If \c ENABLE_DEBUGGER_COMMANDS is set, the "sockets" and
	"socket" debugger commands are registered/unregistered along the way.
	Any other operation is answered with \c B_ERROR.
*/
static status_t
socket_std_ops(int32 op, ...)
{
	if (op == B_MODULE_INIT) {
		// construct the list in place in its static storage
		new (&sSocketList) SocketList;
		mutex_init(&sSocketLock, "socket list");

#if ENABLE_DEBUGGER_COMMANDS
		add_debugger_command("sockets", dump_sockets, "lists all sockets");
		add_debugger_command("socket", dump_socket, "dumps a socket");
#endif

		return B_OK;
	}

	if (op == B_MODULE_UNINIT) {
		ASSERT(sSocketList.IsEmpty());
		mutex_destroy(&sSocketLock);

#if ENABLE_DEBUGGER_COMMANDS
		remove_debugger_command("socket", dump_socket);
		remove_debugger_command("sockets", dump_sockets);
#endif

		return B_OK;
	}

	return B_ERROR;
}
// The socket module's descriptor. This is a positional aggregate initializer:
// the entries have to stay in the order in which the members are declared in
// net_socket_module_info (declared outside of this file), so new hooks must
// be inserted at the matching position.
net_socket_module_info gNetSocketModule = {
// generic module header; socket_std_ops handles B_MODULE_INIT/B_MODULE_UNINIT
{
NET_SOCKET_MODULE_NAME,
0,
socket_std_ops
},
// socket life cycle, I/O and options
socket_open,
socket_close,
socket_free,
socket_readv,
socket_writev,
socket_control,
socket_read_avail,
socket_send_avail,
socket_send_data,
socket_receive_data,
socket_get_option,
socket_set_option,
socket_get_next_stat,
// connections
socket_acquire,
socket_release,
socket_spawn_pending,
socket_dequeue_connected,
socket_count_connected,
socket_set_max_backlog,
socket_has_parent,
socket_connected,
socket_aborted,
// notifications
socket_request_notification,
socket_cancel_notification,
socket_notify,
// standard socket API
socket_accept,
socket_bind,
socket_connect,
socket_getpeername,
socket_getsockname,
socket_getsockopt,
socket_listen,
socket_receive,
socket_send,
socket_setsockopt,
socket_shutdown,
socket_socketpair
};
// Static analyzer notes (PVS-Studio) recorded for this file:
// V547 Expression 'error == ((int) 0)' is always true.
// V547 Expression 'status == ((int) 0)' is always true.