#!/usr/bin/env python
###############
### MODULES ###
###############
import csv
import numpy as np
import spatialmath as sm
###############
### CLASSES ###
###############
class CSVROSDataRow():
    """A helper class to store a single row of data contained within a CSV file created by ROSData tools.

    Each field name becomes an attribute on the instance holding the converted value for that field.
    """

    # Headers holding the pose, in the order returned by get_pose_data().
    _POSE_FIELDS = ('pos_x', 'pos_y', 'pos_z', 'quat_w', 'quat_x', 'quat_y', 'quat_z')

    def __init__(self, data: list, fields: list) -> None:
        """Initialises a CSVROSDataRow object where the class attributes will be those contained within the fields argument.

        For the pose data functions to be accessible the fields argument must contain the strings
        ['pos_x', 'pos_y', 'pos_z', 'quat_w', 'quat_x', 'quat_y', 'quat_z']. Numeric data will be stored
        as floats, all other data will be stored as strings except for the string 'none' (any case) which
        will be stored as None. The caller's data list is not modified.

        Args:
            data (list): the values for the provided fields
            fields (list): the names for each field provided within the data

        Raises:
            ValueError: if the length of the fields and data arguments are not equal.
        """
        # Check fields and data are same length
        if len(fields) != len(data):
            raise ValueError("The number of fields must be equal to size of the data")
        # Convert each value and store it as an attribute. Converted values are not
        # written back into ``data``, so the caller's list is left untouched.
        for field, value in zip(fields, data):
            setattr(self, field, self._convert_value(value))

    @staticmethod
    def _convert_value(value):
        """Converts a raw CSV value to a float, to None (for the string 'none'), or leaves it unchanged.

        Args:
            value: the raw value, normally a string read from a CSV file

        Returns:
            float if the value is numeric, None if it is the string 'none' (any case),
            otherwise the value unchanged.
        """
        try:
            return float(value)
        except (TypeError, ValueError):
            # Not numeric (TypeError covers non-string, non-numeric values such as None).
            pass
        if isinstance(value, str) and value.lower() == 'none':
            return None
        return value

    def get_pose(self) -> 'sm.SE3':
        """Will retrieve the pose as a Spatial Maths SE3 object.

        The pose data must be stored under the headers
        ['pos_x', 'pos_y', 'pos_z', 'quat_w', 'quat_x', 'quat_y', 'quat_z'].

        Returns:
            sm.SE3: the pose data as a SE3 object or None if the pose data does not exist.
        """
        data = self.get_pose_data()
        if data is None:
            return None
        # Start from a translation-only transform, then write the rotation built from the
        # [quat_w, quat_x, quat_y, quat_z] values into the upper-left 3x3 of its matrix.
        # This relies on SE3.A returning the object's underlying array (not a copy).
        # NOTE(review): the quaternion is not normalised here; a non-unit quaternion gives a
        # non-orthonormal rotation block — confirm the recorded quaternions are unit length.
        se3 = sm.SE3(data[:3])
        se3.A[:3, :3] = sm.base.q2r(data[3:])
        return se3

    def get_pose_data(self) -> list:
        """Gets the pose data as individual float values.

        The pose data must be stored under the headers
        ['pos_x', 'pos_y', 'pos_z', 'quat_w', 'quat_x', 'quat_y', 'quat_z'].

        Returns:
            list: the pose data as a list of floats, ordered
            ['pos_x', 'pos_y', 'pos_z', 'quat_w', 'quat_x', 'quat_y', 'quat_z'],
            or None if any of those fields is missing or holds None.
        """
        # A missing header is treated the same as a 'none' value rather than raising AttributeError.
        data = [getattr(self, name, None) for name in self._POSE_FIELDS]
        if any(value is None for value in data):
            return None
        return data
class CSVROSData():
    """A utility class to provide easy accessibility to a CSV file created by ROSData tools.

    The first row of the CSV file is used as the headers (fields); every following non-empty
    row is stored as a CSVROSDataRow.
    """

    # Headers that must all be present for the pose functions to be usable.
    _POSE_FIELDS = ('pos_x', 'pos_y', 'pos_z', 'quat_w', 'quat_x', 'quat_y', 'quat_z')

    # Numeric types accepted as indices. np.floating replaces the np.float alias,
    # which was removed in NumPy 1.24.
    _INDEX_TYPES = (np.integer, np.floating, int, float)

    def __init__(self, csvfile: str) -> None:
        """Initialises a CSVROSData object where the class attributes, referred to as fields, will be the headers of the CSV file.

        For the pose data functions to be accessible the CSV file must contain the headers
        ['pos_x', 'pos_y', 'pos_z', 'quat_w', 'quat_x', 'quat_y', 'quat_z']. Numeric data will be stored
        as floats, all other data will be stored as strings except for the string 'none' which will be
        stored as None. Blank lines in the CSV file are ignored.

        Args:
            csvfile (str): the path to the ROSData CSV file
        """
        # Class Variables
        self._data = []
        self._fields = []
        # Read CSV file; the with-statement closes the file.
        with open(csvfile, newline='') as f:
            csvreader = csv.reader(f, delimiter=',')
            self._fields = next(csvreader, [])
            for row in csvreader:
                # csv.reader yields [] for blank lines; skip them instead of failing
                # the field/data length check in CSVROSDataRow.
                if row:
                    self._data.append(CSVROSDataRow(row, self._fields))

    def pose_data_exists(self) -> bool:
        """Returns true if the CSV ROSData file contained pose information.

        Pose information is considered present when every pose header exists and the first
        data row holds a non-None value for each of them.

        Returns:
            bool: true if pose information is available.
        """
        if not self._data:
            return False
        if not all(self.field_exists(name) for name in self._POSE_FIELDS):
            return False
        return self._data[0].get_pose_data() is not None

    def get_pose(self, index : int) -> 'sm.SE3':
        """Gets the transform (spatialmath.SE3 object) for the given index.

        Args:
            index (int): the index for the desired transform

        Returns:
            spatialmath.SE3: returns a spatialmath.SE3 object with the given transform, or None if a transform does not exist for this index.
        """
        return self._data[index].get_pose()

    def get_pose_data(self, index : int) -> list:
        """Gets the raw pose values for the given index.

        Args:
            index (int): the index for the desired pose

        Returns:
            list: the pose as floats ordered ['pos_x', 'pos_y', 'pos_z', 'quat_w', 'quat_x', 'quat_y', 'quat_z'],
            or None if pose data does not exist for this index.
        """
        return self._data[index].get_pose_data()

    def field_exists(self, field : str) -> bool:
        """Checks to see if a field exists within this CSV data

        Args:
            field (str): the field

        Returns:
            bool: true if the field exists
        """
        return field in self._fields

    def get_data(self, indices=None, fields=None):
        """Gets the entire data for a specific index, or the value for a field for a index, or
        gets the field for the entire data (e.g., all timestamps).

        When the result holds a single item, that item is returned instead of a one-element list.
        When neither indices nor fields are given, all rows are returned.

        Examples:
            | # return data for index 0
            | data = csvrosdata_obj.get_data(0)
            | # return all pos_x data
            | data = csvrosdata_obj.get_data('pos_x')
            | # return all fields for multiple indices
            | data = csvrosdata_obj.get_data([0, 2])
            | # return all data for a set of fields
            | data = csvrosdata_obj.get_data(['pos_x', 'pos_z'])
            | # return multiple fields for a specified index or a set of indices
            | data = csvrosdata_obj.get_data(0, ['pos_x', 'pos_z'])
            | data = csvrosdata_obj.get_data([0, 2], ['pos_x', 'pos_z'])

        Args:
            indices (int, float, str, list): the index or indices to be retrieved. A string or a list of
                strings is treated as the fields argument when fields is not given.
            fields (optional, str or list of str): the field or fields to be retrieved

        Raises:
            ValueError: if the indices or fields arguments are of an unsupported type.

        Returns:
            variable: either the data (list) for a given index, the value for a given index/field or the data (list) for a field across all indices.
        """
        # Support the documented shorthand get_data('pos_x') / get_data(['pos_x', 'pos_z']).
        if fields is None and self._is_field_list(indices):
            indices, fields = None, indices
        indices = self._as_index_list(indices)
        fields = self._as_field_list(fields)

        if not fields:
            # Whole rows: the requested ones, or every row when no index was given.
            retval = [self._data[i] for i in indices] if indices else list(self._data)
        else:
            rows = [self._data[i] for i in indices] if indices else self._data
            if len(fields) == 1:
                retval = [getattr(row, fields[0]) for row in rows]
            else:
                retval = [[getattr(row, name) for name in fields] for row in rows]
        # only return the element if single item in list
        if len(retval) == 1:
            return retval[0]
        return retval

    @staticmethod
    def _is_field_list(value) -> bool:
        """Returns true if value looks like a field selector (a string or a non-empty list of strings)."""
        if isinstance(value, str):
            return True
        return isinstance(value, list) and len(value) > 0 and all(isinstance(x, str) for x in value)

    @classmethod
    def _as_index_list(cls, indices) -> list:
        """Normalises the indices argument to a (possibly empty) list of ints."""
        if indices is None:
            return []
        if isinstance(indices, cls._INDEX_TYPES):
            return [int(indices)]  # change to int and convert to list
        if isinstance(indices, list) and all(isinstance(x, cls._INDEX_TYPES) for x in indices):
            return [int(x) for x in indices]
        raise ValueError("The indices argument must be a integer, float or list of integers or floats.")

    @staticmethod
    def _as_field_list(fields) -> list:
        """Normalises the fields argument to a (possibly empty) list of strings."""
        if fields is None:
            return []
        if isinstance(fields, str):
            return [fields]
        if isinstance(fields, list) and all(isinstance(x, str) for x in fields):
            return fields
        raise ValueError("The fields argument must be a string, or list of strings.")

    def __getitem__(self, item):
        """Can be used as a shorthand for the get_data method.

        Examples:
            | # equivalent to csvrosdata_obj.get_data(indices=0)
            | csvrosdata_obj[0]
            | # equivalent to csvrosdata_obj.get_data(fields='timestamp')
            | csvrosdata_obj['timestamp']
            | # equivalent to csvrosdata_obj.get_data(indices=0, fields='timestamp')
            | csvrosdata_obj[0, 'timestamp']

        Raises:
            ValueError: if the item is not a supported index/field type.
        """
        if isinstance(item, tuple):
            return self.get_data(*item)
        is_list = isinstance(item, list)
        if isinstance(item, self._INDEX_TYPES) or (is_list and all(isinstance(x, self._INDEX_TYPES) for x in item)):
            return self.get_data(indices=item)
        if isinstance(item, str) or (is_list and all(isinstance(x, str) for x in item)):
            return self.get_data(fields=item)
        raise ValueError("Unknown argument type")

    def __len__(self) -> int:
        """returns the length of the data

        Returns:
            int: the length of the data
        """
        return len(self._data)
########################
### PUBLIC FUNCTIONS ###
########################
#########################
### PRIVATE FUNCTIONS ###
#########################