/*
 * Copyright 2006-2010, Haiku, Inc. All Rights Reserved.
 * Distributed under the terms of the MIT License.
 *
 * Authors:
 *		Axel Dörfler, axeld@pinc-software.de
 */
 
 
#include "domains.h"
#include "interfaces.h"
#include "routes.h"
#include "stack_private.h"
#include "utility.h"
 
#include <net_device.h>
#include <NetUtilities.h>
 
#include <lock.h>
#include <util/AutoLock.h>
 
#include <KernelExport.h>
 
#include <net/if_dl.h>
#include <net/route.h>
#include <new>
#include <stdlib.h>
#include <string.h>
#include <sys/sockio.h>
 
 
// Debug tracing for the routing code: uncomment the TRACE_ROUTES define below
// to have TRACE() forward its arguments to dprintf(), prefixed with
// STACK_DEBUG_PREFIX. Without it, TRACE() expands to an empty statement.
//#define TRACE_ROUTES
#ifdef TRACE_ROUTES
#	define TRACE(x...) dprintf(STACK_DEBUG_PREFIX x)
#else
#	define TRACE(x...) ;
#endif
 
 
/*!	Creates an empty route without addresses or interface address.
	Every member gets a defined value, so that routes which are only partially
	filled in afterwards (like the temporary one control_routes() builds on
	the stack) never carry indeterminate data.
*/
net_route_private::net_route_private()
{
	destination = mask = gateway = NULL;
	interface_address = NULL;
	flags = 0;
	mtu = 0;
	ref_count = 1;
		// the creator holds the initial reference, as in add_route()
}
 
 
/*!	Releases the address copies owned by this route.
	The interface address is not touched here; put_route_internal() drops
	that reference before it deletes a route.
*/
net_route_private::~net_route_private()
{
	// free(NULL) is a no-op, so unset addresses need no special casing
	free(gateway);
	free(mask);
	free(destination);
}
 
 
//	#pragma mark - private functions
 
 
/*!	Copies a socket address of arbitrary length from userland into a newly
	allocated kernel buffer.

	Each byte of the userland address is fetched exactly once: the header
	that determined the allocation size is reused for the copy, so a
	concurrent change of \c sa_len in userland cannot make the copy claim to
	be larger than its buffer.

	\param from The userland address; may be \c NULL.
	\param to Receives the \c malloc()'ed copy, which the caller has to
		\c free(). It is set to \c NULL if \a from is \c NULL, if the
		allocation failed, or if copying the remainder failed.
	\return \c B_OK on success (including a \c NULL \a from),
		\c B_BAD_ADDRESS if the userland memory could not be read,
		\c B_NO_MEMORY if the allocation failed.
*/
static status_t
user_copy_address(const sockaddr* from, sockaddr** to)
{
	if (from == NULL) {
		*to = NULL;
		return B_OK;
	}

	// Fetch the fixed size part first to learn the actual length.
	sockaddr address;
	if (user_memcpy(&address, from, sizeof(struct sockaddr)) < B_OK)
		return B_BAD_ADDRESS;

	const size_t length = address.sa_len;

	sockaddr* copy = (sockaddr*)malloc(length);
	if (copy == NULL) {
		*to = NULL;
		return B_NO_MEMORY;
	}

	if (length > sizeof(struct sockaddr)) {
		// Keep the header we already validated, and only fetch the part
		// that follows it from userland.
		memcpy(copy, &address, sizeof(struct sockaddr));
		if (user_memcpy((uint8*)copy + sizeof(struct sockaddr),
				(const uint8*)from + sizeof(struct sockaddr),
				length - sizeof(struct sockaddr)) < B_OK) {
			// Don't hand a half filled buffer back to the caller.
			free(copy);
			*to = NULL;
			return B_BAD_ADDRESS;
		}
	} else
		memcpy(copy, &address, length);

	*to = copy;
	return B_OK;
}
 
 
/*!	Copies a socket address from userland into the caller supplied
	\c sockaddr_storage.

	The length is taken from the first fetch only; the remainder of a long
	address is copied without reading the header again, so the \c ss_len
	that was checked against the storage size is the one that ends up in
	\a to.

	\param from The userland address; must not be \c NULL.
	\param to The kernel storage to fill in.
	\return \c B_OK on success,
		\c B_BAD_ADDRESS if \a from is \c NULL or could not be read,
		\c B_BAD_VALUE if the address claims to be larger than a
		\c sockaddr_storage.
*/
static status_t
user_copy_address(const sockaddr* from, sockaddr_storage* to)
{
	if (from == NULL)
		return B_BAD_ADDRESS;

	if (user_memcpy(to, from, sizeof(sockaddr)) < B_OK)
		return B_BAD_ADDRESS;

	const size_t length = to->ss_len;
	if (length > sizeof(sockaddr)) {
		if (length > sizeof(sockaddr_storage))
			return B_BAD_VALUE;

		// Only fetch what follows the already copied header, so that the
		// validated length cannot be replaced behind our back.
		if (user_memcpy((uint8*)to + sizeof(sockaddr),
				(const uint8*)from + sizeof(sockaddr),
				length - sizeof(sockaddr)) < B_OK)
			return B_BAD_ADDRESS;
	}

	return B_OK;
}
 
 
/*!	Searches the routing table of \a _domain for the route described by
	\a description.

	If both the description and a table entry are default routes, only their
	interface addresses are compared. Otherwise the route type flags,
	destination (masked with the description's mask), mask, and gateway have
	to be equal; the interface address is only compared if the description
	specifies one.

	\return The matching route, or \c NULL if there is none.
*/
static net_route_private*
find_route(struct net_domain* _domain, const net_route* description)
{
	struct net_domain_private* domain = (net_domain_private*)_domain;

	const uint32 kTypeFlags
		= RTF_GATEWAY | RTF_HOST | RTF_LOCAL | RTF_DEFAULT;
	const bool wantsDefault = (description->flags & RTF_DEFAULT) != 0;

	for (RouteList::Iterator it = domain->routes.GetIterator();
			it.HasNext();) {
		net_route_private* current = it.Next();

		if (wantsDefault && (current->flags & RTF_DEFAULT) != 0) {
			// there can only be one default route per interface address family
			// TODO: check this better
			if (current->interface_address == description->interface_address)
				return current;

			continue;
		}

		// Each test below rules the current entry out; the order is the
		// one the combined condition has always been evaluated in.
		if ((current->flags & kTypeFlags)
				!= (description->flags & kTypeFlags))
			continue;

		if (!domain->address_module->equal_masked_addresses(
				current->destination, description->destination,
				description->mask))
			continue;

		if (!domain->address_module->equal_addresses(current->mask,
				description->mask))
			continue;

		if (!domain->address_module->equal_addresses(current->gateway,
				description->gateway))
			continue;

		if (description->interface_address != NULL
			&& description->interface_address != current->interface_address)
			continue;

		return current;
	}

	return NULL;
}
 
 
/*!	Finds the route to use for \a address in the routing table of
	\a _domain.

	A route matches if its destination equals \a address, after \a address
	has been masked with the route's mask (if the route has one). The first
	matching route whose device has \c IFF_LINK set is returned right away.
	If no such route exists, the first matching route without link is
	returned instead.

	\return The route found, or \c NULL if no route matches.
*/
static net_route_private*
find_route(net_domain* _domain, const sockaddr* address)
{
	net_domain_private* domain = (net_domain_private*)_domain;

	RouteList::Iterator iterator = domain->routes.GetIterator();
	net_route_private* linkless = NULL;
		// first matching route whose device has no link

	TRACE("test address %s for routes...\n",
		AddressString(domain, address).Data());

	// TODO: alternate equal default routes

	while (iterator.HasNext()) {
		net_route_private* current = iterator.Next();

		bool matches;
		if (current->mask) {
			sockaddr masked;
			domain->address_module->mask_address(address, current->mask,
				&masked);
			matches = domain->address_module->equal_addresses(&masked,
				current->destination);
		} else {
			matches = domain->address_module->equal_addresses(address,
				current->destination);
		}

		if (!matches)
			continue;

		const bool hasLink
			= (current->interface_address->interface->device->flags
				& IFF_LINK) != 0;
		if (hasLink) {
			TRACE("  found route: %s, flags %lx\n",
				AddressString(domain, current->destination).Data(),
				current->flags);
			return current;
		}

		// neglect routes that point to devices that have no link, but
		// remember the first one as fallback
		if (linkless == NULL) {
			TRACE("  found candidate: %s, flags %lx\n", AddressString(
				domain, current->destination).Data(), current->flags);
			linkless = current;
		}
	}

	return linkless;
}
 
 
static void
put_route_internal(struct net_domain_private* domain, net_route* _route)
{
	ASSERT_LOCKED_RECURSIVE(&domain->lock);
 
	net_route_private* route = (net_route_private*)_route;
	if (route == NULL || atomic_add(&route->ref_count, -1) != 1)
		return;
 
	// delete route - it must already have been removed at this point
	if (route->interface_address != NULL)
		((InterfaceAddress*)route->interface_address)->ReleaseReference();
 
	delete route;
}
 
 
static struct net_route*
get_route_internal(struct net_domain_private* domain,
	const struct sockaddr* address)
{
	ASSERT_LOCKED_RECURSIVE(&domain->lock);
	net_route_private* route = NULL;
 
	if (address->sa_family == AF_LINK) {
		// special address to find an interface directly
		RouteList::Iterator iterator = domain->routes.GetIterator();
		const sockaddr_dl* link = (const sockaddr_dl*)address;
 
		while (iterator.HasNext()) {
			route = iterator.Next();
 
			net_device* device = route->interface_address->interface->device;
 
			if ((link->sdl_nlen > 0
					&& !strncmp(device->name, (const char*)link->sdl_data,
							IF_NAMESIZE))
				|| (link->sdl_nlen == 0 && link->sdl_alen > 0
					&& !memcmp(LLADDR(link), device->address.data,
							device->address.length)))
				break;
		}
	} else
		route = find_route(domain, address);
 
	if (route != NULL && atomic_add(&route->ref_count, 1) == 0) {
		// route has been deleted already
		route = NULL;
	}
 
	return route;
}
 
 
static void
update_route_infos(struct net_domain_private* domain)
{
	ASSERT_LOCKED_RECURSIVE(&domain->lock);
	RouteInfoList::Iterator iterator = domain->route_infos.GetIterator();
 
	while (iterator.HasNext()) {
		net_route_info* info = iterator.Next();
 
		put_route_internal(domain, info->route);
		info->route = get_route_internal(domain, &info->address);
	}
}
 
 
/*!	Pushes \a address (\c sa_len bytes) onto \a buffer.
	\return Whatever UserBuffer::Push() returned, or \c NULL if \a address
		is \c NULL, in which case the buffer is left untouched.
*/
static sockaddr*
copy_address(UserBuffer& buffer, sockaddr* address)
{
	if (address != NULL)
		return (sockaddr*)buffer.Push(address, address->sa_len);

	return NULL;
}
 
 
/*!	Fills in \a target with the data of \a route.
	The addresses (destination, mask, gateway, and the local address of the
	route's interface address) are pushed in that order into the part of
	\a _buffer that follows the leading \c route_entry; \a target's pointers
	are set to what the pushes returned. Flags and MTU are copied directly.

	\param bufferSize Size of \a _buffer including the leading
		\c route_entry; the caller has to make sure it is at least that
		large.
	\return The status of the UserBuffer after all pushes.
*/
static status_t
fill_route_entry(route_entry* target, void* _buffer, size_t bufferSize,
	net_route* route)
{
	uint8* payload = ((uint8*)_buffer) + sizeof(route_entry);
	UserBuffer buffer(payload, bufferSize - sizeof(route_entry));

	// plain values first, they don't depend on the buffer
	target->flags = route->flags;
	target->mtu = route->mtu;

	// the order of the pushes determines the layout within the buffer
	target->destination = copy_address(buffer, route->destination);
	target->mask = copy_address(buffer, route->mask);
	target->gateway = copy_address(buffer, route->gateway);
	target->source = copy_address(buffer, route->interface_address->local);

	return buffer.Status();
}
 
 
//	#pragma mark - exported functions
 
 
/*!	Computes how many bytes a buffer needs to provide in order to hold the
	complete routing table of \a domain: for each route, the interface name
	plus a \c route_entry, followed by those of its destination, mask, and
	gateway addresses that are set.
*/
uint32
route_table_size(net_domain_private* domain)
{
	RecursiveLocker locker(domain->lock);

	uint32 total = 0;

	for (RouteList::Iterator it = domain->routes.GetIterator();
			it.HasNext();) {
		net_route_private* route = it.Next();

		// fixed part of every entry
		total += IF_NAMESIZE + sizeof(route_entry);

		// variable part: only the addresses that are actually there
		if (route->destination != NULL)
			total += route->destination->sa_len;
		if (route->mask != NULL)
			total += route->mask->sa_len;
		if (route->gateway != NULL)
			total += route->gateway->sa_len;
	}

	return total;
}
 
 
/*!	Dumps a list of all routes into the supplied userland buffer.
	If the routes don't fit into the buffer, an error (\c ENOBUFS) is
	returned.

	Each route is written as the leading part of an \c ifreq (interface name
	plus \c route_entry), directly followed by those of its destination,
	mask, and gateway addresses that are set. The address pointers stored in
	the \c route_entry point to these copies within the userland buffer.

	\param domain The domain whose routing table is dumped.
	\param buffer The userland buffer to fill.
	\param size The size of \a buffer in bytes.
	\return \c B_OK on success, \c ENOBUFS if the buffer is too small for
		the next entry (entries before it have already been written),
		\c B_BAD_ADDRESS if writing to the buffer failed.
*/
status_t
list_routes(net_domain_private* domain, void* buffer, size_t size)
{
	RecursiveLocker _(domain->lock);

	RouteList::Iterator iterator = domain->routes.GetIterator();
	const size_t kBaseSize = IF_NAMESIZE + sizeof(route_entry);
	size_t spaceLeft = size;

	while (iterator.HasNext()) {
		net_route* route = iterator.Next();

		// Compute where the addresses of this entry will end up in the
		// userland buffer, and how much space the entry needs in total.
		size_t entrySize = kBaseSize;

		sockaddr* destination = NULL;
		sockaddr* mask = NULL;
		sockaddr* gateway = NULL;
		uint8* next = (uint8*)buffer + entrySize;

		if (route->destination != NULL) {
			destination = (sockaddr*)next;
			next += route->destination->sa_len;
			entrySize += route->destination->sa_len;
		}
		if (route->mask != NULL) {
			mask = (sockaddr*)next;
			next += route->mask->sa_len;
			entrySize += route->mask->sa_len;
		}
		if (route->gateway != NULL) {
			gateway = (sockaddr*)next;
			next += route->gateway->sa_len;
			entrySize += route->gateway->sa_len;
		}

		// nothing of this entry is written unless all of it fits
		if (spaceLeft < entrySize)
			return ENOBUFS;

		ifreq request;
		memset(&request, 0, sizeof(request));

		strlcpy(request.ifr_name, route->interface_address->interface->name,
			IF_NAMESIZE);
		request.ifr_route.destination = destination;
		request.ifr_route.mask = mask;
		request.ifr_route.gateway = gateway;
		request.ifr_route.mtu = route->mtu;
		request.ifr_route.flags = route->flags;

		// copy data into userland buffer
		if (user_memcpy(buffer, &request, kBaseSize) < B_OK
			|| (route->destination != NULL
				&& user_memcpy(request.ifr_route.destination,
					route->destination, route->destination->sa_len) < B_OK)
			|| (route->mask != NULL && user_memcpy(request.ifr_route.mask,
					route->mask, route->mask->sa_len) < B_OK)
			|| (route->gateway != NULL && user_memcpy(request.ifr_route.gateway,
					route->gateway, route->gateway->sa_len) < B_OK))
			return B_BAD_ADDRESS;

		buffer = (void*)next;
		spaceLeft -= entrySize;
	}

	return B_OK;
}
 
 
/*!	Handles the route related ioctls \c SIOCADDRT and \c SIOCDELRT.

	The \c route_entry is read from the userland \c ifreq in \a argument, its
	addresses are copied into the kernel, and the resulting route is handed
	to add_route() or remove_route(). The route is bound to the first address
	of \a _interface for the domain's family.

	\return The result of add_route()/remove_route(), \c B_BAD_VALUE for any
		other \a option or if \a length is not the size of an \c ifreq,
		\c B_BAD_ADDRESS if the \c route_entry could not be read, or the
		error of the failed address copy.
*/
status_t
control_routes(struct net_interface* _interface, net_domain* domain,
	int32 option, void* argument, size_t length)
{
	TRACE("control_routes(interface %p, domain %p, option %" B_PRId32 ")\n",
		_interface, domain, option);
	Interface* interface = (Interface*)_interface;

	// only adding and removing routes is supported here
	if (option != SIOCADDRT && option != SIOCDELRT)
		return B_BAD_VALUE;

	if (length != sizeof(struct ifreq))
		return B_BAD_VALUE;

	route_entry entry;
	if (user_memcpy(&entry, &((ifreq*)argument)->ifr_route,
			sizeof(route_entry)) != B_OK)
		return B_BAD_ADDRESS;

	// The temporary route owns the address copies; its destructor frees
	// them on every path out of this function.
	net_route_private route;

	status_t status = user_copy_address(entry.destination,
		&route.destination);
	if (status == B_OK)
		status = user_copy_address(entry.mask, &route.mask);
	if (status == B_OK)
		status = user_copy_address(entry.gateway, &route.gateway);
	if (status != B_OK)
		return status;

	InterfaceAddress* address = interface->FirstForFamily(domain->family);

	route.mtu = entry.mtu;
	route.flags = entry.flags;
	route.interface_address = address;

	status = option == SIOCADDRT
		? add_route(domain, &route) : remove_route(domain, &route);

	if (address != NULL)
		address->ReleaseReference();

	return status;
}
 
 
/*!	Adds a copy of \a newRoute to the routing table of \a _domain.

	The new route gets its own copies of the addresses, acquires a reference
	to the interface address, and starts out with a reference count of 1 and
	an MTU of 0. It is inserted in front of the first route with a less
	specific mask; among default routes, it is inserted in front of the first
	one whose device has a lower link speed. Afterwards, all registered route
	infos are updated.

	\return \c B_OK on success,
		\c B_BAD_VALUE if an argument is missing or the description is
		inconsistent (no interface address, a host route with mask, a
		non-default route without destination, a gateway route without
		gateway, or a mask the address module rejects),
		\c B_FILE_EXISTS if a matching route already exists,
		\c B_NO_MEMORY if the route or its addresses could not be allocated.
*/
status_t
add_route(struct net_domain* _domain, const struct net_route* newRoute)
{
	struct net_domain_private* domain = (net_domain_private*)_domain;

	// These need to be checked before the TRACE() below dereferences them.
	if (domain == NULL || newRoute == NULL)
		return B_BAD_VALUE;

	TRACE("add route to domain %s: dest %s, mask %s, gw %s, flags %lx\n",
		domain->name,
		AddressString(domain, newRoute->destination
			? newRoute->destination : NULL).Data(),
		AddressString(domain, newRoute->mask ? newRoute->mask : NULL).Data(),
		AddressString(domain, newRoute->gateway
			? newRoute->gateway : NULL).Data(),
		newRoute->flags);

	if (newRoute->interface_address == NULL
		|| ((newRoute->flags & RTF_HOST) != 0 && newRoute->mask != NULL)
		|| ((newRoute->flags & RTF_DEFAULT) == 0
			&& newRoute->destination == NULL)
		|| ((newRoute->flags & RTF_GATEWAY) != 0 && newRoute->gateway == NULL)
		|| !domain->address_module->check_mask(newRoute->mask))
		return B_BAD_VALUE;

	RecursiveLocker _(domain->lock);

	net_route_private* route = find_route(domain, newRoute);
	if (route != NULL)
		return B_FILE_EXISTS;

	route = new (std::nothrow) net_route_private;
	if (route == NULL)
		return B_NO_MEMORY;

	// Addresses that were copied before a failure are freed by the
	// route's destructor.
	if (domain->address_module->copy_address(newRoute->destination,
			&route->destination, (newRoute->flags & RTF_DEFAULT) != 0,
			newRoute->mask) != B_OK
		|| domain->address_module->copy_address(newRoute->mask, &route->mask,
			(newRoute->flags & RTF_DEFAULT) != 0, NULL) != B_OK
		|| domain->address_module->copy_address(newRoute->gateway,
			&route->gateway, false, NULL) != B_OK) {
		delete route;
		return B_NO_MEMORY;
	}

	route->flags = newRoute->flags;
	route->interface_address = newRoute->interface_address;
	((InterfaceAddress*)route->interface_address)->AcquireReference();
	route->mtu = 0;
	route->ref_count = 1;

	// Insert the route sorted by completeness of its mask

	RouteList::Iterator iterator = domain->routes.GetIterator();
	net_route_private* before = NULL;

	while ((before = iterator.Next()) != NULL) {
		// if the before mask is less specific than the one of the route,
		// we can insert it before that route.
		if (domain->address_module->first_mask_bit(before->mask)
				> domain->address_module->first_mask_bit(route->mask))
			break;

		if ((route->flags & RTF_DEFAULT) != 0
			&& (before->flags & RTF_DEFAULT) != 0) {
			// both routes are equal - let the link speed decide the
			// order
			if (before->interface_address->interface->device->link_speed
					< route->interface_address->interface->device->link_speed)
				break;
		}
	}

	// "before" is NULL if the loop ran through the whole list
	domain->routes.Insert(before, route);
	update_route_infos(domain);

	return B_OK;
}
 
 
/*!	Removes the route matching \a removeRoute from the routing table of
	\a _domain, releases the table's reference to it, and updates all
	registered route infos.

	\return \c B_OK if a route was removed, \c B_ENTRY_NOT_FOUND if no route
		matches the description.
*/
status_t
remove_route(struct net_domain* _domain, const struct net_route* removeRoute)
{
	struct net_domain_private* domain = (net_domain_private*)_domain;

	TRACE("remove route from domain %s: dest %s, mask %s, gw %s, flags %lx\n",
		domain->name,
		AddressString(domain, removeRoute->destination
			? removeRoute->destination : NULL).Data(),
		AddressString(domain, removeRoute->mask
			? removeRoute->mask : NULL).Data(),
		AddressString(domain, removeRoute->gateway
			? removeRoute->gateway : NULL).Data(),
		removeRoute->flags);

	RecursiveLocker locker(domain->lock);

	net_route_private* match = find_route(domain, removeRoute);
	if (match != NULL) {
		// take it out of the table first, then let go of our reference
		domain->routes.Remove(match);
		put_route_internal(domain, match);

		update_route_infos(domain);
		return B_OK;
	}

	return B_ENTRY_NOT_FOUND;
}
 
 
/*!	Looks up the route for the destination given in the userland
	\c route_entry at \a value, and writes the route's data back to it.

	\param value Userland buffer that starts with a \c route_entry; the
		route's addresses are stored in the space following it.
	\param length Size of the buffer at \a value.
	\return \c B_OK on success, \c B_BAD_VALUE if \a length is smaller than
		a \c route_entry, \c B_BAD_ADDRESS if userland memory could not be
		accessed, \c B_ENTRY_NOT_FOUND if there is no route to the
		destination, or the error of copying the destination or of filling
		in the entry.
*/
status_t
get_route_information(struct net_domain* _domain, void* value, size_t length)
{
	if (length < sizeof(route_entry))
		return B_BAD_VALUE;

	route_entry entry;
	if (user_memcpy(&entry, value, sizeof(route_entry)) < B_OK)
		return B_BAD_ADDRESS;

	sockaddr_storage destination;
	status_t status = user_copy_address(entry.destination, &destination);
	if (status != B_OK)
		return status;

	struct net_domain_private* domain = (net_domain_private*)_domain;
	RecursiveLocker locker(domain->lock);

	net_route_private* route = find_route(domain, (sockaddr*)&destination);
	if (route == NULL)
		return B_ENTRY_NOT_FOUND;

	status = fill_route_entry(&entry, value, length, route);
	if (status == B_OK) {
		// the entry itself goes last, now that its pointers are known
		status = user_memcpy(value, &entry, sizeof(route_entry));
	}

	return status;
}
 
 
/*!	Removes all routes of \a _domain whose interface address belongs to
	\a interface.
*/
void
invalidate_routes(net_domain* _domain, net_interface* interface)
{
	net_domain_private* domain = (net_domain_private*)_domain;
	RecursiveLocker locker(domain->lock);

	TRACE("invalidate_routes(%i, %s)\n", domain->family, interface->name);

	for (RouteList::Iterator it = domain->routes.GetIterator();
			it.HasNext();) {
		net_route* current = it.Next();

		if (current->interface_address->interface != interface)
			continue;

		remove_route(domain, current);
	}
}
 
 
/*!	Removes all routes that are bound to the interface address \a address
	from the routing table of the address' domain.
*/
void
invalidate_routes(InterfaceAddress* address)
{
	net_domain_private* domain = (net_domain_private*)address->domain;

	TRACE("invalidate_routes(%s)\n",
		AddressString(domain, address->local).Data());

	RecursiveLocker locker(domain->lock);

	for (RouteList::Iterator it = domain->routes.GetIterator();
			it.HasNext();) {
		net_route* current = it.Next();

		if (current->interface_address != address)
			continue;

		remove_route(domain, current);
	}
}
 
 
/*!	Locked front end of get_route_internal(): acquires the domain's lock,
	and returns a referenced route for \a address, or \c NULL if there is
	none.
*/
struct net_route*
get_route(struct net_domain* _domain, const struct sockaddr* address)
{
	struct net_domain_private* domain = (net_domain_private*)_domain;

	RecursiveLocker locker(domain->lock);
	struct net_route* route = get_route_internal(domain, address);

	return route;
}
 
 
/*!	Returns the direct route of the interface that
	get_interface_for_device() finds for the device \a index in \a domain.
	A reference to the route is acquired for the caller.

	\param _route Set to the direct route on success.
	\return \c B_OK on success, \c ENETUNREACH if no interface was found.
*/
status_t
get_device_route(struct net_domain* domain, uint32 index, net_route** _route)
{
	Interface* interface = get_interface_for_device(domain, index);
	if (interface != NULL) {
		net_route_private* direct
			= &interface->DomainDatalink(domain->family)->direct_route;

		atomic_add(&direct->ref_count, 1);
		*_route = direct;

		// the interface reference was only needed for the lookup
		interface->ReleaseReference();
		return B_OK;
	}

	return ENETUNREACH;
}
 
 
/*!	Finds the route for the destination of \a buffer, and lets the address
	module update the buffer's source address from the local address of the
	route's interface address (if there is one).

	\param _route Set to the referenced route on success only.
	\return \c B_OK on success, \c ENETUNREACH if there is no route to the
		destination, or the error returned by the address module, in which
		case the route's reference has been released again.
*/
status_t
get_buffer_route(net_domain* _domain, net_buffer* buffer, net_route** _route)
{
	net_domain_private* domain = (net_domain_private*)_domain;

	RecursiveLocker _(domain->lock);

	net_route* route = get_route_internal(domain, buffer->destination);
	if (route == NULL)
		return ENETUNREACH;

	// TODO: we are quite relaxed in the address checking here
	// as we might proceed with source = INADDR_ANY.

	status_t status = B_OK;
	if (route->interface_address != NULL
		&& route->interface_address->local != NULL) {
		status = domain->address_module->update_to(buffer->source,
			route->interface_address->local);
	}

	if (status == B_OK) {
		*_route = route;
		return B_OK;
	}

	// the caller doesn't get the route, so give up its reference
	put_route_internal(domain, route);
	return status;
}
 
 
/*!	Locked front end of put_route_internal(): releases one reference to
	\a route. Does nothing if either argument is \c NULL.
*/
void
put_route(struct net_domain* _domain, net_route* route)
{
	if (_domain == NULL || route == NULL)
		return;

	struct net_domain_private* domain = (net_domain_private*)_domain;

	RecursiveLocker locker(domain->lock);
	put_route_internal(domain, route);
}
 
 
/*!	Adds \a info to the route infos of \a _domain, and resolves the route
	for its address right away. Registered infos are re-resolved by
	update_route_infos().
	\return Always \c B_OK.
*/
status_t
register_route_info(struct net_domain* _domain, struct net_route_info* info)
{
	struct net_domain_private* domain = (net_domain_private*)_domain;

	RecursiveLocker locker(domain->lock);

	domain->route_infos.Add(info);

	// may be NULL if there is no route to the address yet
	info->route = get_route_internal(domain, &info->address);

	return B_OK;
}
 
 
/*!	Removes \a info from the route infos of \a _domain, and releases the
	reference to the route it currently holds, if any.
	\return Always \c B_OK.
*/
status_t
unregister_route_info(struct net_domain* _domain, struct net_route_info* info)
{
	struct net_domain_private* domain = (net_domain_private*)_domain;

	RecursiveLocker locker(domain->lock);

	domain->route_infos.Remove(info);

	if (info->route == NULL)
		return B_OK;

	put_route_internal(domain, info->route);
	return B_OK;
}
 
 
/*!	Re-resolves the route of a single route info: the reference to the
	current route is released, and the route for the info's address is
	looked up anew.
	\return Always \c B_OK.
*/
status_t
update_route_info(struct net_domain* _domain, struct net_route_info* info)
{
	struct net_domain_private* domain = (net_domain_private*)_domain;

	RecursiveLocker locker(domain->lock);

	// let go of the old route before looking up its replacement
	put_route_internal(domain, info->route);
	info->route = get_route_internal(domain, &info->address);

	return B_OK;
}
 

// Static analysis note (PVS-Studio V730): flagged that not all members of a
// class are initialized inside the constructor. Consider inspecting: ref_count.