threads: add storage api, based on flow storage

pull/12117/head
Jason Ish 5 months ago committed by Victor Julien
parent a6fc37c90a
commit fa230efccb

@ -438,6 +438,7 @@ noinst_HEADERS = \
suricata-common.h \
suricata.h \
suricata-plugin.h \
thread-storage.h \
threads-debug.h \
threads.h \
threads-profile.h \
@ -990,6 +991,7 @@ libsuricata_c_a_SOURCES = \
stream-tcp-sack.c \
stream-tcp-util.c \
suricata.c \
thread-storage.c \
threads.c \
tm-modules.c \
tmqh-flow.c \

@ -0,0 +1,212 @@
/* Copyright (C) 2024 Open Information Security Foundation
*
* You can copy, redistribute or modify this Program under the terms of
* the GNU General Public License version 2 as published by the Free
* Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* version 2 along with this program; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
* 02110-1301, USA.
*/
#include "suricata-common.h"
#include "thread-storage.h"
#include "util-storage.h"
#include "util-unittest.h"
const StorageEnum storage_type = STORAGE_THREAD;
unsigned int ThreadStorageSize(void)
{
return StorageGetSize(storage_type);
}
void *ThreadGetStorageById(const ThreadVars *tv, ThreadStorageId id)
{
return StorageGetById(tv->storage, storage_type, id.id);
}
int ThreadSetStorageById(ThreadVars *tv, ThreadStorageId id, void *ptr)
{
return StorageSetById(tv->storage, storage_type, id.id, ptr);
}
void *ThreadAllocStorageById(ThreadVars *tv, ThreadStorageId id)
{
return StorageAllocByIdPrealloc(tv->storage, storage_type, id.id);
}
void ThreadFreeStorageById(ThreadVars *tv, ThreadStorageId id)
{
StorageFreeById(tv->storage, storage_type, id.id);
}
void ThreadFreeStorage(ThreadVars *tv)
{
if (ThreadStorageSize() > 0)
StorageFreeAll(tv->storage, storage_type);
}
ThreadStorageId ThreadStorageRegister(const char *name, const unsigned int size,
void *(*Alloc)(unsigned int), void (*Free)(void *))
{
int id = StorageRegister(storage_type, name, size, Alloc, Free);
ThreadStorageId tsi = { .id = id };
return tsi;
}
#ifdef UNITTESTS
static void *StorageTestAlloc(unsigned int size)
{
return SCCalloc(1, size);
}
static void StorageTestFree(void *x)
{
SCFree(x);
}
static int ThreadStorageTest01(void)
{
StorageInit();
ThreadStorageId id1 = ThreadStorageRegister("test", 8, StorageTestAlloc, StorageTestFree);
FAIL_IF(id1.id < 0);
ThreadStorageId id2 = ThreadStorageRegister("variable", 24, StorageTestAlloc, StorageTestFree);
FAIL_IF(id2.id < 0);
ThreadStorageId id3 =
ThreadStorageRegister("store", sizeof(void *), StorageTestAlloc, StorageTestFree);
FAIL_IF(id3.id < 0);
FAIL_IF(StorageFinalize() < 0);
ThreadVars *tv = SCCalloc(1, sizeof(ThreadVars) + ThreadStorageSize());
FAIL_IF_NULL(tv);
void *ptr = ThreadGetStorageById(tv, id1);
FAIL_IF_NOT_NULL(ptr);
ptr = ThreadGetStorageById(tv, id2);
FAIL_IF_NOT_NULL(ptr);
ptr = ThreadGetStorageById(tv, id3);
FAIL_IF_NOT_NULL(ptr);
void *ptr1a = ThreadAllocStorageById(tv, id1);
FAIL_IF_NULL(ptr1a);
void *ptr2a = ThreadAllocStorageById(tv, id2);
FAIL_IF_NULL(ptr2a);
void *ptr3a = ThreadAllocStorageById(tv, id3);
FAIL_IF_NULL(ptr3a);
void *ptr1b = ThreadGetStorageById(tv, id1);
FAIL_IF(ptr1a != ptr1b);
void *ptr2b = ThreadGetStorageById(tv, id2);
FAIL_IF(ptr2a != ptr2b);
void *ptr3b = ThreadGetStorageById(tv, id3);
FAIL_IF(ptr3a != ptr3b);
ThreadFreeStorage(tv);
StorageCleanup();
SCFree(tv);
PASS;
}
static int ThreadStorageTest02(void)
{
StorageInit();
ThreadStorageId id1 = ThreadStorageRegister("test", sizeof(void *), NULL, StorageTestFree);
FAIL_IF(id1.id < 0);
FAIL_IF(StorageFinalize() < 0);
ThreadVars *tv = SCCalloc(1, sizeof(ThreadVars) + ThreadStorageSize());
FAIL_IF_NULL(tv);
void *ptr = ThreadGetStorageById(tv, id1);
FAIL_IF_NOT_NULL(ptr);
void *ptr1a = SCMalloc(128);
FAIL_IF_NULL(ptr1a);
ThreadSetStorageById(tv, id1, ptr1a);
void *ptr1b = ThreadGetStorageById(tv, id1);
FAIL_IF(ptr1a != ptr1b);
ThreadFreeStorage(tv);
StorageCleanup();
PASS;
}
static int ThreadStorageTest03(void)
{
StorageInit();
ThreadStorageId id1 = ThreadStorageRegister("test1", sizeof(void *), NULL, StorageTestFree);
FAIL_IF(id1.id < 0);
ThreadStorageId id2 = ThreadStorageRegister("test2", sizeof(void *), NULL, StorageTestFree);
FAIL_IF(id2.id < 0);
ThreadStorageId id3 = ThreadStorageRegister("test3", 32, StorageTestAlloc, StorageTestFree);
FAIL_IF(id3.id < 0);
FAIL_IF(StorageFinalize() < 0);
ThreadVars *tv = SCCalloc(1, sizeof(ThreadVars) + ThreadStorageSize());
FAIL_IF_NULL(tv);
void *ptr = ThreadGetStorageById(tv, id1);
FAIL_IF_NOT_NULL(ptr);
void *ptr1a = SCMalloc(128);
FAIL_IF_NULL(ptr1a);
ThreadSetStorageById(tv, id1, ptr1a);
void *ptr2a = SCMalloc(256);
FAIL_IF_NULL(ptr2a);
ThreadSetStorageById(tv, id2, ptr2a);
void *ptr3a = ThreadAllocStorageById(tv, id3);
FAIL_IF_NULL(ptr3a);
void *ptr1b = ThreadGetStorageById(tv, id1);
FAIL_IF(ptr1a != ptr1b);
void *ptr2b = ThreadGetStorageById(tv, id2);
FAIL_IF(ptr2a != ptr2b);
void *ptr3b = ThreadGetStorageById(tv, id3);
FAIL_IF(ptr3a != ptr3b);
ThreadFreeStorage(tv);
StorageCleanup();
PASS;
}
#endif
void RegisterThreadStorageTests(void)
{
#ifdef UNITTESTS
UtRegisterTest("ThreadStorageTest01", ThreadStorageTest01);
UtRegisterTest("ThreadStorageTest02", ThreadStorageTest02);
UtRegisterTest("ThreadStorageTest03", ThreadStorageTest03);
#endif
}

@ -0,0 +1,45 @@
/* Copyright (C) 2024 Open Information Security Foundation
*
* You can copy, redistribute or modify this Program under the terms of
* the GNU General Public License version 2 as published by the Free
* Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* version 2 along with this program; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
* 02110-1301, USA.
*/
/**
* Thread wrapper around storage API.
*/
#ifndef SURICATA_THREAD_STORAGE_H
#define SURICATA_THREAD_STORAGE_H
#include "threadvars.h"
typedef struct ThreadStorageId {
int id;
} ThreadStorageId;
unsigned int ThreadStorageSize(void);
void *ThreadGetStorageById(const ThreadVars *tv, ThreadStorageId id);
int ThreadSetStorageById(ThreadVars *tv, ThreadStorageId id, void *ptr);
void *ThreadAllocStorageById(ThreadVars *tv, ThreadStorageId id);
void ThreadFreeStorageById(ThreadVars *tv, ThreadStorageId id);
void ThreadFreeStorage(ThreadVars *tv);
void RegisterThreadStorageTests(void);
ThreadStorageId ThreadStorageRegister(const char *name, const unsigned int size,
void *(*Alloc)(unsigned int), void (*Free)(void *));
#endif /* SURICATA_THREAD_STORAGE_H */

@ -25,6 +25,7 @@
*/
#include "suricata-common.h"
#include "thread-storage.h"
#include "util-unittest.h"
#include "util-debug.h"
#include "threads.h"
@ -149,5 +150,6 @@ void ThreadMacrosRegisterTests(void)
UtRegisterTest("ThreadMacrosTest03RWLocks", ThreadMacrosTest03RWLocks);
UtRegisterTest("ThreadMacrosTest04RWLocks", ThreadMacrosTest04RWLocks);
// UtRegisterTest("ThreadMacrosTest05RWLocks", ThreadMacrosTest05RWLocks);
RegisterThreadStorageTests();
#endif /* UNIT TESTS */
}

@ -28,6 +28,7 @@
#include "counters.h"
#include "packet-queue.h"
#include "util-atomic.h"
#include "util-storage.h"
struct TmSlot_;
@ -135,6 +136,7 @@ typedef struct ThreadVars_ {
struct FlowQueue_ *flow_queue;
bool break_loop;
Storage storage[];
} ThreadVars;
/** Thread setup flags: */

@ -30,6 +30,7 @@
#include "stream.h"
#include "runmodes.h"
#include "threadvars.h"
#include "thread-storage.h"
#include "tm-queues.h"
#include "tm-queuehandlers.h"
#include "tm-threads.h"
@ -919,7 +920,7 @@ ThreadVars *TmThreadCreate(const char *name, const char *inq_name, const char *i
SCLogDebug("creating thread \"%s\"...", name);
/* XXX create separate function for this: allocate a thread container */
tv = SCCalloc(1, sizeof(ThreadVars));
tv = SCCalloc(1, sizeof(ThreadVars) + ThreadStorageSize());
if (unlikely(tv == NULL))
goto error;
@ -1577,6 +1578,8 @@ static void TmThreadFree(ThreadVars *tv)
SCLogDebug("Freeing thread '%s'.", tv->name);
ThreadFreeStorage(tv);
if (tv->flow_queue) {
BUG_ON(tv->flow_queue->qlen != 0);
SCFree(tv->flow_queue);

@ -59,6 +59,8 @@ static const char *StoragePrintType(StorageEnum type)
return "ippair";
case STORAGE_DEVICE:
return "livedevice";
case STORAGE_THREAD:
return "thread";
case STORAGE_MAX:
return "max";
}

@ -31,6 +31,7 @@ typedef enum StorageEnum_ {
STORAGE_FLOW,
STORAGE_IPPAIR,
STORAGE_DEVICE,
STORAGE_THREAD,
STORAGE_MAX,
} StorageEnum;

Loading…
Cancel
Save